"""Functions to plot M/EEG data e.g. topographies."""

# Authors: The MNE-Python contributors.
# License: BSD-3-Clause
# Copyright the MNE-Python contributors.

import copy
import itertools
import warnings
from functools import partial
from numbers import Integral

import numpy as np
from scipy.interpolate import (
    CloughTocher2DInterpolator,
    LinearNDInterpolator,
    NearestNDInterpolator,
)
from scipy.sparse import csr_array
from scipy.spatial import Delaunay, Voronoi
from scipy.spatial.distance import pdist, squareform

from .._fiff.constants import FIFF
from .._fiff.meas_info import Info, _simplify_info
from .._fiff.pick import (
    _MEG_CH_TYPES_SPLIT,
    _pick_data_channels,
    _picks_by_type,
    _picks_to_idx,
    pick_channels,
    pick_info,
    pick_types,
)
from ..baseline import rescale
from ..defaults import (
    _BORDER_DEFAULT,
    _EXTRAPOLATE_DEFAULT,
    _INTERPOLATION_DEFAULT,
    _handle_default,
)
from ..transforms import apply_trans, invert_transform
from ..utils import (
    _check_option,
    _check_sphere,
    _clean_names,
    _is_numeric,
    _time_mask,
    _validate_type,
    check_version,
    fill_doc,
    legacy,
    logger,
    verbose,
    warn,
)
from ..utils.spectrum import _split_psd_kwargs
from .ui_events import TimeChange, publish, subscribe
from .utils import (
    DraggableColorbar,
    _check_delayed_ssp,
    _check_time_unit,
    _check_type_projs,
    _draw_proj_checkbox,
    _format_units_psd,
    _get_cmap,
    _get_plot_ch_type,
    _prepare_sensor_names,
    _prepare_trellis,
    _process_times,
    _set_3d_axes_equal,
    _setup_cmap,
    _setup_vmin_vmax,
    _validate_if_list_of_axes,
    figure_nobar,
    plot_sensors,
    plt_show,
)

# Channel types that select the fNIRS path in _prepare_topomap_plot, where
# co-located sensors are detected and merged.
_fnirs_types = (
    "hbo",
    "hbr",
    "fnirs_cw_amplitude",
    "fnirs_od",
)
# Coil types that mark a recording as OPM (QuSpin ZF-OPM magnetometers, per
# the constant names) for the co-located-sensor handling.
_opm_coils = (
    FIFF.FIFFV_COIL_QUSPIN_ZFOPM_MAG,
    FIFF.FIFFV_COIL_QUSPIN_ZFOPM_MAG2,
)


# Matplotlib 3.8+ represents a contour set as a single Collection artist
# instead of exposing a ``.collections`` list.
# https://github.com/matplotlib/matplotlib/pull/25247
def _cont_collections(cont):
    """Return the artists that make up the contour set ``cont`` as a tuple."""
    if not check_version("matplotlib", "3.8"):
        return tuple(cont.collections)
    return (cont,)


def _adjust_meg_sphere(sphere, info, ch_type):
    """Resolve ``sphere`` and, for MEG channel types, move it to device X/Y.

    Parameters
    ----------
    sphere : object
        Sphere specification, resolved with ``_check_sphere(sphere, info)``.
    info : Info
        Measurement info; ``info["dev_head_t"]`` is read for MEG types.
    ch_type : str
        Channel type being plotted. Must not be None.

    Returns
    -------
    sphere : ndarray
        The resolved sphere. For "mag"/"grad"/"planar1"/"planar2", when a
        device-to-head transform exists, the centre is mapped with its inverse
        and the z coordinate is set to 0.
    clip_origin : tuple | ndarray
        ``(0.0, 0.0)`` for MEG types, otherwise a copy of the sphere's x/y.
    """
    sphere = _check_sphere(sphere, info)
    assert ch_type is not None
    if ch_type not in ("mag", "grad", "planar1", "planar2"):
        # Non-MEG data: clip around the (head-coordinate) sphere centre.
        return sphere, sphere[:2].copy()

    dev_head_t = info["dev_head_t"]
    if dev_head_t is not None:
        # Move sphere X/Y from head coordinates into device X/Y space.
        sphere[:3] = apply_trans(invert_transform(dev_head_t), sphere[:3])
        # Z only affects flattening. The head size could instead vary with
        # depth in the helmet (e.g. ``sphere[2] /= -5``), but an orthographic
        # projection is assumed here for explicitness / simplicity.
        sphere[2] = 0.0
    return sphere, (0.0, 0.0)


def _prepare_topomap_plot(inst, ch_type, sphere=None):
    """Prepare topo plot.

    Parameters
    ----------
    inst : Info | object with an ``info`` attribute
        Source of the channel information. It is deep-copied, so the
        caller's measurement info is not modified.
    ch_type : str
        The channel type to plot.
    sphere : object | None
        Head sphere specification, passed to ``_adjust_meg_sphere``.

    Returns
    -------
    picks : array-like of int
        Indices of the channels to take data from.
    pos : ndarray, shape (n_pos, 2)
        2D topomap positions.
    merge_channels : bool | list
        ``True`` for paired Vectorview / Neuromag-122 gradiometers, the list
        of overlapping channel sets for co-located fNIRS / OPM sensors, or
        ``False`` when no merging is needed.
    ch_names : list of str
        Channel names for ``picks``, renamed to mark merged channels.
    ch_type : str
        The requested channel type, returned unchanged.
    sphere : ndarray
        The sphere after ``_adjust_meg_sphere``.
    clip_origin : tuple | ndarray
        The clipping origin from ``_adjust_meg_sphere``.

    Raises
    ------
    ValueError
        If no channels of ``ch_type`` are found (non-fNIRS, non-OPM case).
    """
    from ..channels.layout import _find_topomap_coords, _pair_grad_sensors, find_layout

    info = copy.deepcopy(inst if isinstance(inst, Info) else inst.info)
    sphere, clip_origin = _adjust_meg_sphere(sphere, info, ch_type)

    # Use cleaned channel names everywhere (channels, compensators, bads).
    clean_ch_names = _clean_names(info["ch_names"])
    for ii, this_ch in enumerate(info["chs"]):
        this_ch["ch_name"] = clean_ch_names[ii]
    for comp in info["comps"]:
        comp["data"]["col_names"] = _clean_names(comp["data"]["col_names"])
    info._update_redundant()
    info["bads"] = _clean_names(info["bads"])
    info._check_consistency()

    # The OPM path picks MEG channels only, so it must only be used when MEG
    # data are being plotted. Otherwise e.g. an EEG or fNIRS topomap of a
    # recording that also contains OPM sensors would be drawn from MEG channels.
    if ch_type in ("mag", "grad", "planar1", "planar2") and any(
        ch["coil_type"] in _opm_coils for ch in info["chs"]
    ):
        modality = "opm"
    elif ch_type in _fnirs_types:
        modality = "fnirs"
    else:
        modality = "other"

    # Only filled in by _find_overlaps (fNIRS / OPM modalities).
    overlapping_channels = list()

    # special case for merging grad channels
    layout = find_layout(info)
    if (
        ch_type == "grad"
        and layout is not None
        and (
            layout.kind.startswith("Vectorview")
            or layout.kind.startswith("Neuromag_122")
        )
    ):
        picks, _ = _pair_grad_sensors(info, layout)
        # one position per gradiometer pair
        pos = _find_topomap_coords(info, picks[::2], sphere=sphere)
        merge_channels = True
    elif modality != "other":
        picks, pos, merge_channels, overlapping_channels = _find_overlaps(
            info, ch_type, sphere, modality=modality
        )
    else:
        merge_channels = False
        if ch_type == "eeg":
            picks = pick_types(info, meg=False, eeg=True, ref_meg=False, exclude="bads")
        elif ch_type == "csd":
            picks = pick_types(info, meg=False, csd=True, ref_meg=False, exclude="bads")
        elif ch_type == "dbs":
            picks = pick_types(info, meg=False, dbs=True, ref_meg=False, exclude="bads")
        elif ch_type == "seeg":
            picks = pick_types(
                info, meg=False, seeg=True, ref_meg=False, exclude="bads"
            )
        else:
            picks = pick_types(info, meg=ch_type, ref_meg=False, exclude="bads")

        if len(picks) == 0:
            raise ValueError(f"No channels of type {ch_type!r}")

        pos = _find_topomap_coords(info, picks, sphere=sphere)

    ch_names = [info["ch_names"][k] for k in picks]
    if modality == "fnirs":
        # Remove the chroma label type for cleaner labeling.
        ch_names = [k[:-4] for k in ch_names]

    if merge_channels:
        if ch_type == "grad":
            # change names so that vectorview combined grads appear as MEG014x
            # instead of MEG0142 or MEG0143 which are the 2 planar grads.
            ch_names = [ch_names[k][:-1] + "x" for k in range(0, len(ch_names), 2)]
        elif modality == "fnirs":
            # Modify the channel names to indicate they are to be merged
            # New names will have the form  S1_D1xS2_D2
            # More than two channels can overlap and be merged
            for set_ in overlapping_channels:
                idx = ch_names.index(set_[0][:-4])
                new_name = "x".join(s[:-4] for s in set_)
                ch_names[idx] = new_name
        elif modality == "opm":
            # indicate that non-radial channels are to be removed
            for set_ in overlapping_channels:
                for set_ch in set_[1:]:
                    idx = ch_names.index(set_ch)
                    new_name = set_ch + "_MERGE-REMOVE"
                    ch_names[idx] = new_name

    pos = np.array(pos)[:, :2]  # 2D plot, otherwise interpolation bugs
    return picks, pos, merge_channels, ch_names, ch_type, sphere, clip_origin


def _find_overlaps(info, ch_type, sphere, modality="fnirs"):
    """Find overlapping channels.

    Channels whose 3D locations coincide (pairwise distance below 1e-10) are
    grouped so that they can be merged into a single topomap position.

    Parameters
    ----------
    info : Info
        Measurement info.
    ch_type : str
        The fNIRS channel type to pick (unused for ``modality="opm"``).
    sphere : ndarray
        Passed to ``_find_topomap_coords``.
    modality : "fnirs" | "opm"
        Which channels to consider: fNIRS channels of ``ch_type``, or all MEG
        channels (reference MEG excluded).

    Returns
    -------
    picks : array of int
        Indices of the channels of the modality (merged ones included).
    pos : ndarray
        Topomap coordinates. When overlaps exist, only non-overlapping
        channels and the first channel of each overlapping set (bads
        excluded) contribute a position.
    merge_channels : list | False
        ``overlapping_channels`` when overlaps were found, else ``False``.
    overlapping_channels : list of ndarray of str
        One entry per overlapping set. For fNIRS the first name is the channel
        that keeps the position. For OPM the most radial channel (see
        ``_find_radial_channel``) is moved to the front.

    Raises
    ------
    ValueError
        If ``modality`` is neither "fnirs" nor "opm".
    """
    from ..channels.layout import _find_topomap_coords

    if modality not in ("fnirs", "opm"):
        raise ValueError(f"Invalid modality for colocated sensors: {modality}")

    def _pick(**kwargs):
        """Pick this modality's channels, forwarding ``exclude`` if given."""
        if modality == "fnirs":
            return pick_types(info, meg=False, ref_meg=False, fnirs=ch_type, **kwargs)
        return pick_types(info, meg=True, ref_meg=False, **kwargs)

    picks = _pick(exclude="bads")
    chs = [info["chs"][i] for i in picks]
    locs3d = np.array([ch["loc"][:3] for ch in chs])
    dist = pdist(locs3d)

    # Store the sets of channels to be merged
    overlapping_channels = list()
    # Channels to be excluded from picks, as will be removed after merging
    channels_to_exclude = list()

    if len(locs3d) > 1 and np.min(dist) < 1e-10:
        # upper triangle: row i flags the channels j > i co-located with i
        overlapping_mask = np.triu(squareform(dist < 1e-10))
        already_overlapped = set()
        for chan_idx, row in enumerate(overlapping_mask):
            this_name = chs[chan_idx]["ch_name"]
            if not row.any() or this_name in already_overlapped:
                continue
            # Determine the set of channels to be combined, with this channel
            # first. Build from a list: np.insert would cast the inserted name
            # to the fixed-width string dtype of the others and could
            # truncate it.
            overlapping_set = np.array(
                [this_name] + [chs[ii]["ch_name"] for ii in np.where(row)[0]]
            )
            if modality == "opm":
                # Make sure the radial channel is first in the overlapping set
                rad_channel = _find_radial_channel(info, overlapping_set)
                overlapping_set = np.array(
                    [rad_channel] + [ch for ch in overlapping_set if ch != rad_channel]
                )
            overlapping_channels.append(overlapping_set)
            already_overlapped.update(overlapping_set)
            channels_to_exclude.append(overlapping_set[1:])

        exclude = list(itertools.chain.from_iterable(channels_to_exclude))
        exclude.extend(info["bads"])
        # positions only for the channels that survive merging
        pos = _find_topomap_coords(info, _pick(exclude=exclude), sphere=sphere)
        # data picks keep every channel of the modality (pick_types defaults)
        picks = _pick()

        # Overload the merge_channels variable as this is returned to calling
        # function and indicates that merging of data is required
        merge_channels = overlapping_channels
    else:
        merge_channels = False
        pos = _find_topomap_coords(info, picks, sphere=sphere)

    return picks, pos, merge_channels, overlapping_channels


def _find_radial_channel(info, overlapping_set):
    """Find the most radial channel in the overlapping set."""
    if len(overlapping_set) == 1:
        return overlapping_set[0]
    elif len(overlapping_set) < 1:
        raise ValueError("No overlapping channels found.")

    radial_score = np.zeros(len(overlapping_set))
    for s, sens in enumerate(overlapping_set):
        ch_idx = pick_channels(info["ch_names"], [sens])[0]
        radial_direction = info["chs"][ch_idx]["loc"][0:3].copy()
        radial_direction /= np.linalg.norm(radial_direction)

        orientation_vector = info["chs"][ch_idx]["loc"][9:12]
        if info["dev_head_t"] is not None:
            orientation_vector = apply_trans(
                info["dev_head_t"], orientation_vector, move=False
            )
        radial_score[s] = np.abs(np.dot(radial_direction, orientation_vector))

    radial_sensor = overlapping_set[np.argmax(radial_score)]

    return radial_sensor


def _plot_update_evoked_topomap(params, bools):
    """Update topomaps.

    Re-applies the projectors selected by ``bools`` to a copy of
    ``params["evoked"]``, re-interpolates every topomap image and redraws its
    contour lines, then redraws the figure canvas.

    Parameters
    ----------
    params : dict
        Plot state. Reads "projs", "evoked", "time_idx", "scale",
        "merge_channels", "interp", "contours_", "contours", "axes",
        "images" and "fig". Writes "proj_bools" and "contours_".
    bools : sequence of bool
        One flag per entry of ``params["projs"]``; truthy keeps the projector.
    """
    from ..channels.layout import _merge_ch_data

    # Keep projectors whose flag is set: a single pass instead of recomputing
    # np.where(bools) for every projector.
    projs = [proj for proj, keep in zip(params["projs"], bools) if keep]

    params["proj_bools"] = bools
    new_evoked = params["evoked"].copy()
    with new_evoked.info._unlock():
        new_evoked.info["projs"] = []
    new_evoked.add_proj(projs)
    new_evoked.apply_proj()

    data = new_evoked.data[:, params["time_idx"]] * params["scale"]
    if params["merge_channels"]:
        data, _ = _merge_ch_data(data, "grad", [])

    interp = params["interp"]
    new_contours = list()
    use_contours = params["contours_"]
    if not len(use_contours):
        # no stored contour sets: only the images are updated
        use_contours = [None] * len(params["axes"])
    assert len(use_contours) == len(params["images"])
    assert len(params["axes"]) == len(params["images"])
    assert len(data.T) == len(params["images"])
    for cont, ax, im, d in zip(use_contours, params["axes"], params["images"], data.T):
        Zi = interp.set_values(d)()
        im.set_data(Zi)
        if cont is None:
            continue
        # must be removed and re-added
        cont_collections = _cont_collections(cont)
        for col in cont_collections:
            col.remove()
        # carry the styling of the first old artist over to the new contours
        col = cont_collections[0]
        lw = col.get_linewidth()
        visible = col.get_visible()
        patch_ = col.get_clip_path()
        color = col.get_edgecolors()
        cont = ax.contour(
            interp.Xi, interp.Yi, Zi, params["contours"], colors=color, linewidths=lw
        )
        cont_collections = _cont_collections(cont)
        for col in cont_collections:
            col.set_visible(visible)
            col.set_clip_path(patch_)
        new_contours.append(cont)
    params["contours_"] = new_contours

    params["fig"].canvas.draw()


def _add_colorbar(
    ax,
    im,
    cmap,
    *,
    title=None,
    format_=None,
    kind=None,
    ch_type=None,
):
    """Add a colorbar for ``im`` next to ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes
        Axes whose figure receives the colorbar (shrunk to 0.6).
    im : mappable
        The image the colorbar describes.
    cmap : sequence | None
        When given and ``cmap[1]`` is truthy, the colorbar is made draggable
        and stored as ``ax.CB``. Presumably a (colormap, interactive) pair;
        confirm against callers.
    title : str | None
        Optional title drawn above the colorbar.
    format_ : str | None
        Tick label format passed to ``Figure.colorbar``.
    kind, ch_type : str | None
        Forwarded to ``DraggableColorbar``.

    Returns
    -------
    cbar : Colorbar
    cax : Axes
        The colorbar's axes.
    """
    figure = ax.figure
    cbar = figure.colorbar(im, format=format_, shrink=0.6)
    make_draggable = cmap is not None and cmap[1]
    if make_draggable:
        ax.CB = DraggableColorbar(cbar, im, kind, ch_type)
    cax = cbar.ax
    if title is None:
        return cbar, cax
    cax.set_title(title, y=1.05, fontsize=10)
    return cbar, cax


def _eliminate_zeros(proj):
    """Remove grad or mag data if only contains 0s (gh 5641)."""
    GRAD_ENDING = ("2", "3")
    MAG_ENDING = "1"

    proj = copy.deepcopy(proj)
    proj["data"]["data"] = np.atleast_2d(proj["data"]["data"])

    for ending in (GRAD_ENDING, MAG_ENDING):
        names = proj["data"]["col_names"]
        idx = [i for i, name in enumerate(names) if name.endswith(ending)]

        # if all 0, remove the 0s an their labels
        if not proj["data"]["data"][0][idx].any():
            new_col_names = np.delete(np.array(names), idx).tolist()
            new_data = np.delete(np.array(proj["data"]["data"][0]), idx)
            proj["data"]["col_names"] = new_col_names
            proj["data"]["data"] = np.array([new_data])

    proj["data"]["ncol"] = len(proj["data"]["col_names"])
    return proj


@fill_doc
def plot_projs_topomap(
    projs,
    info,
    *,
    sensors=True,
    show_names=False,
    contours=6,
    outlines="head",
    sphere=None,
    image_interp=_INTERPOLATION_DEFAULT,
    extrapolate=_EXTRAPOLATE_DEFAULT,
    border=_BORDER_DEFAULT,
    res=64,
    size=1,
    cmap=None,
    vlim=(None, None),
    cnorm=None,
    colorbar=False,
    cbar_fmt="%3.1f",
    units=None,
    axes=None,
    show=True,
):
    """Plot topographic maps of SSP projections.

    Parameters
    ----------
    projs : list of Projection
        The projections.
    %(info_not_none)s Must be associated with the channels in the projectors.

        .. versionchanged:: 0.20
            The positional argument ``layout`` was replaced by ``info``.
    %(sensors_topomap)s
    %(show_names_topomap)s

        .. versionadded:: 1.2
    %(contours_topomap)s
    %(outlines_topomap)s
    %(sphere_topomap_auto)s
    %(image_interp_topomap)s
    %(extrapolate_topomap)s

        .. versionadded:: 0.20

        .. versionchanged:: 0.21

           - The default was changed to ``'local'`` for MEG sensors.
           - ``'local'`` was changed to use a convex hull mask
           - ``'head'`` was changed to extrapolate out to the clipping circle.
    %(border_topomap)s

        .. versionadded:: 0.20
    %(res_topomap)s
    %(size_topomap)s
    %(cmap_topomap)s
    %(vlim_plot_topomap_proj)s
    %(cnorm)s

        .. versionadded:: 1.2
    %(colorbar_topomap)s
    %(cbar_fmt_topomap)s

        .. versionadded:: 1.2
    %(units_topomap)s

        .. versionadded:: 1.2
    %(axes_plot_projs_topomap)s
    %(show)s

    Returns
    -------
    fig : instance of matplotlib.figure.Figure
        Figure with a topomap subplot for each projector.

    Notes
    -----
    .. versionadded:: 0.9.0
    """
    # All drawing is delegated to the private worker; this public wrapper only
    # adds the ``show`` handling. (The former empty
    # ``warnings.catch_warnings()`` block wrapped no code and was removed.)
    fig = _plot_projs_topomap(
        projs,
        info,
        sensors=sensors,
        show_names=show_names,
        contours=contours,
        outlines=outlines,
        sphere=sphere,
        image_interp=image_interp,
        extrapolate=extrapolate,
        border=border,
        res=res,
        size=size,
        cmap=cmap,
        vlim=vlim,
        cnorm=cnorm,
        colorbar=colorbar,
        cbar_fmt=cbar_fmt,
        units=units,
        axes=axes,
    )
    plt_show(show)
    return fig


def _plot_projs_topomap(
    projs,
    info,
    sensors=True,
    show_names=False,
    contours=6,
    outlines="head",
    sphere=None,
    image_interp=_INTERPOLATION_DEFAULT,
    extrapolate=_EXTRAPOLATE_DEFAULT,
    border=_BORDER_DEFAULT,
    res=64,
    size=1,
    cmap=None,
    vlim=(None, None),
    cnorm=None,
    colorbar=False,
    cbar_fmt="%3.1f",
    units=None,
    axes=None,
):
    """Plot SSP projectors as topomaps (worker for ``plot_projs_topomap``).

    Each projector is split into one map per channel type it covers. See
    ``plot_projs_topomap`` for the meaning of the parameters.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure containing the last axes drawn into.
    """
    import matplotlib.pyplot as plt

    from ..channels.layout import _merge_ch_data

    sphere = _check_sphere(sphere, info)
    projs = _check_type_projs(projs)
    _validate_type(info, "info", "info")

    # Preprocess projs to deal with joint MEG projectors. If we duplicate these and
    # split into mag and grad, they should work as expected
    info_names = _clean_names(info["ch_names"], remove_whitespace=True)
    use_projs = list()
    for proj in projs:
        proj = _eliminate_zeros(proj)  # gh 5641, makes a copy
        proj["data"]["col_names"] = _clean_names(
            proj["data"]["col_names"],
            remove_whitespace=True,
        )
        picks = pick_channels(info_names, proj["data"]["col_names"], ordered=True)
        proj_types = info.get_channel_types(picks)
        unique_types = sorted(set(proj_types))
        for type_ in unique_types:
            proj_picks = np.where([proj_type == type_ for proj_type in proj_types])[0]
            use_projs.append(copy.deepcopy(proj))
            use_projs[-1]["data"]["data"] = proj["data"]["data"][:, proj_picks]
            use_projs[-1]["data"]["col_names"] = [
                proj["data"]["col_names"][pick] for pick in proj_picks
            ]
    projs = use_projs

    datas, poss, spheres, outliness, ch_typess, namess = [], [], [], [], [], []
    for proj in projs:
        # get ch_names, ch_types, data
        data = proj["data"]["data"].ravel()
        picks = pick_channels(info_names, proj["data"]["col_names"], ordered=True)
        use_info = pick_info(info, picks)
        these_ch_types = use_info.get_channel_types(unique=True)
        assert len(these_ch_types) == 1  # should be guaranteed above
        ch_type = these_ch_types[0]
        (
            data_picks,
            pos,
            merge_channels,
            names,
            _,
            this_sphere,
            clip_origin,
        ) = _prepare_topomap_plot(use_info, ch_type, sphere=sphere)
        # Build the outlines from the sphere returned by _prepare_topomap_plot
        # (it may have been adjusted, e.g. moved to MEG device space) so they
        # agree with ``pos``, ``clip_origin`` and the sphere given to
        # plot_topomap below.
        these_outlines = _make_head_outlines(this_sphere, pos, outlines, clip_origin)
        data = data[data_picks]
        if merge_channels:
            data, _ = _merge_ch_data(data, "grad", [])
            data = data.ravel()

        # populate containers
        datas.append(data)
        poss.append(pos)
        spheres.append(this_sphere)
        outliness.append(these_outlines)
        ch_typess.append(ch_type)
        # names of exactly the channels in ``data`` (merged grads get one name)
        namess.append(names)
        del data, pos, this_sphere, these_outlines, ch_type, names
    del sphere

    # setup axes
    n_projs = len(projs)
    if axes is None:
        fig, axes, ncols, nrows = _prepare_trellis(
            n_projs, ncols="auto", nrows="auto", size=size, sharex=True, sharey=True
        )
    elif isinstance(axes, plt.Axes):
        axes = [axes]
    _validate_if_list_of_axes(axes, n_projs)

    # handle vmin/vmax
    vlims = [None for _ in range(len(datas))]
    if vlim == "joint":
        # one shared range per channel type
        for _ch_type in set(ch_typess):
            idx = np.where(np.isin(ch_typess, _ch_type))[0]
            these_data = np.concatenate(np.array(datas, dtype=object)[idx])
            norm = all(these_data >= 0)
            _vl = _setup_vmin_vmax(these_data, vmin=None, vmax=None, norm=norm)
            for _idx in idx:
                vlims[_idx] = _vl
        # make sure we got a vlim for all projs
        assert all([vl is not None for vl in vlims])
    else:
        vlims = [vlim] * len(datas)

    # plot
    for proj, ax, _data, _pos, _vlim, _sphere, _outlines, _ch_type, _names in zip(
        projs, axes, datas, poss, vlims, spheres, outliness, ch_typess, namess
    ):
        # ch_names: use this projector's plotted channels so the names line up
        # with ``_data`` (previously all channels of the type in ``info`` were
        # used, which mismatches merged gradiometers or partial projectors)
        names = _prepare_sensor_names(_names, show_names)
        # title
        title = proj["desc"]
        title = "\n".join(title[ii : ii + 22] for ii in range(0, len(title), 22))
        ax.set_title(title, fontsize=10)
        # plot
        im, _ = plot_topomap(
            _data,
            _pos[:, :2],
            vlim=_vlim,
            cmap=cmap,
            sensors=sensors,
            names=names,
            res=res,
            axes=ax,
            outlines=_outlines,
            contours=contours,
            cnorm=cnorm,
            image_interp=image_interp,
            show=False,
            extrapolate=extrapolate,
            sphere=_sphere,
            border=border,
            ch_type=_ch_type,
        )

        if colorbar:
            _add_colorbar(
                ax,
                im,
                cmap,
                title=units,
                format_=cbar_fmt,
                kind="projs_topomap",
                ch_type=_ch_type,
            )

    return ax.get_figure()


def _make_head_outlines(sphere, pos, outlines, clip_origin):
    """Check or create outlines for topoplot."""
    assert isinstance(sphere, np.ndarray)
    x, y, _, radius = sphere
    del sphere

    if outlines in ("head", None):
        ll = np.linspace(0, 2 * np.pi, 101)
        head_x = np.cos(ll) * radius + x
        head_y = np.sin(ll) * radius + y
        dx = np.exp(np.arccos(np.deg2rad(12)) * 1j)
        dx, dy = dx.real, dx.imag
        nose_x = np.array([-dx, 0, dx]) * radius + x
        nose_y = np.array([dy, 1.15, dy]) * radius + y
        ear_x = np.array(
            [0.497, 0.510, 0.518, 0.5299, 0.5419, 0.54, 0.547, 0.532, 0.510, 0.489]
        ) * (radius * 2)
        ear_y = (
            np.array(
                [
                    0.0555,
                    0.0775,
                    0.0783,
                    0.0746,
                    0.0555,
                    -0.0055,
                    -0.0932,
                    -0.1313,
                    -0.1384,
                    -0.1199,
                ]
            )
            * (radius * 2)
            + y
        )

        if outlines is not None:
            # Define the outline of the head, ears and nose
            outlines_dict = dict(
                head=(head_x, head_y),
                nose=(nose_x, nose_y),
                ear_left=(-ear_x + x, ear_y),
                ear_right=(ear_x + x, ear_y),
            )
        else:
            outlines_dict = dict()

        # Make the figure encompass slightly more than all points
        # We probably want to ensure it always contains our most
        # extremely positioned channels, so we do:
        mask_scale = max(1.0, np.linalg.norm(pos, axis=1).max() * 1.01 / radius)
        outlines_dict["mask_pos"] = (mask_scale * head_x, mask_scale * head_y)
        clip_radius = radius * mask_scale
        outlines_dict["clip_radius"] = (clip_radius,) * 2
        outlines_dict["clip_origin"] = clip_origin
        outlines = outlines_dict

    elif isinstance(outlines, dict):
        if "mask_pos" not in outlines:
            raise ValueError("You must specify the coordinates of the image mask.")
    else:
        raise ValueError("Invalid value for `outlines`.")

    return outlines


def _draw_outlines(ax, outlines):
    """Draw the outlines for a topomap.

    Every entry of ``outlines`` is unpacked as an ``(x, y)`` pair. Entries
    whose key contains "mask", or is "clip_radius" / "clip_origin", are not
    drawn. All others are plotted as unclipped lines in the axes edge color.

    Returns
    -------
    outlines_ : dict
        A copy of ``outlines`` without the ``"patch"`` entry.
    """
    from matplotlib import rcParams

    kept = {key: val for key, val in outlines.items() if key != "patch"}
    not_drawn = ("clip_radius", "clip_origin")
    for key, (xs, ys) in kept.items():
        if "mask" not in key and key not in not_drawn:
            ax.plot(
                xs,
                ys,
                color=rcParams["axes.edgecolor"],
                linewidth=1,
                clip_on=False,
            )
    return kept


def _get_extra_points(pos, extrapolate, origin, radii):
    """Get coordinates of additional interpolation points.

    Parameters
    ----------
    pos : ndarray, shape (n_points, n_dims)
        Sensor positions. Outside of ``"box"`` mode the code indexes columns
        0 and 1 and reshapes to (-1, 2), so 2D positions are expected there.
    extrapolate : "head" | "box" | "local"
        How to place the extra points ("auto" must already be resolved).
    origin : tuple of float
        (x, y) centre of the circle used by ``"head"``.
    radii : array-like, shape (2,)
        x / y radii of the circle used by ``"head"``.

    Returns
    -------
    new_pos : ndarray
        The extra (exterior) points.
    mask_pos : ndarray | None
        A mask outline for ``"local"`` (the extra points themselves in the
        colinear / fewer-than-4-points case); ``None`` for "box" and "head".
    tri : scipy.spatial.Delaunay
        Triangulation of ``pos`` followed by the extra points.
    """
    radii = np.array(radii, float)
    assert radii.shape == (2,)
    x, y = origin
    # auto should be gone by now
    _check_option("extrapolate", extrapolate, ("head", "box", "local"))

    # the old method of placement - large box: corners of the bounding box
    # grown by its own extent on every side
    mask_pos = None
    if extrapolate == "box":
        extremes = np.array([pos.min(axis=0), pos.max(axis=0)])
        diffs = extremes[1] - extremes[0]
        extremes[0] -= diffs
        extremes[1] += diffs
        # for 2D positions this indexes the 4 corners (min/max per column)
        eidx = np.array(
            list(itertools.product(*([[0] * (pos.shape[1] - 1) + [1]] * pos.shape[1])))
        )
        pidx = np.tile(np.arange(pos.shape[1])[np.newaxis], (len(eidx), 1))
        outer_pts = extremes[eidx, pidx]
        return outer_pts, mask_pos, Delaunay(np.concatenate((pos, outer_pts)))

    # check if positions are colinear: all consecutive slopes (in the given
    # order) compare exactly equal, or all are infinite (vertical line).
    # NOTE(review): a single-point input gives empty ``slopes`` and
    # ``slopes[0]`` raises IndexError.
    diffs = np.diff(pos, axis=0)
    with np.errstate(divide="ignore"):
        slopes = diffs[:, 1] / diffs[:, 0]
    colinear = (slopes == slopes[0]).all() or np.isinf(slopes).all()

    # compute median inter-electrode distance
    if colinear or pos.shape[0] < 4:
        # sort along one axis and use consecutive distances.
        # NOTE(review): the axis choice compares signed sums of diffs (net
        # displacement), not absolute extents; confirm intended.
        dim = 1 if diffs[:, 1].sum() > diffs[:, 0].sum() else 0
        sorting = np.argsort(pos[:, dim])
        pos_sorted = pos[sorting, :]
        diffs = np.diff(pos_sorted, axis=0)
        distances = np.linalg.norm(diffs, axis=1)
        distance = np.median(distances)
    else:
        # incremental so that the extra points can be added at the end
        tri = Delaunay(pos, incremental=True)
        idx1, idx2, idx3 = tri.simplices.T
        # lengths of the (idx1, idx2) and (idx2, idx3) edges of every triangle
        distances = np.concatenate(
            [
                np.linalg.norm(pos[i1, :] - pos[i2, :], axis=1)
                for i1, i2 in zip([idx1, idx2], [idx2, idx3])
            ]
        )
        distance = np.median(distances)

    if extrapolate == "local":
        if colinear or pos.shape[0] < 4:
            # special case for colinear points and when there are too
            # few points for Delaunay (needs at least 3)
            edge_points = sorting[[0, -1]]
            line_len = np.diff(pos[edge_points, :], axis=0)
            # vector along the line, scaled to ``distance``
            unit_vec = line_len / np.linalg.norm(line_len) * distance
            # despite the name this is perpendicular: (dx, dy) -> (-dy, dx)
            unit_vec_par = unit_vec[:, ::-1] * [[-1, 1]]

            # extra points beyond both ends, and on both sides of every point
            edge_pos = pos[edge_points, :] + np.concatenate(
                [-unit_vec, unit_vec], axis=0
            )
            new_pos = np.concatenate(
                [pos + unit_vec_par, pos - unit_vec_par, edge_pos], axis=0
            )

            if pos.shape[0] == 3:
                # there may be some new_pos points that are too close
                # to the original points
                new_pos_diff = pos[..., np.newaxis] - new_pos.T[np.newaxis, :]
                new_pos_diff = np.linalg.norm(new_pos_diff, axis=1)
                good_extra = (new_pos_diff > 0.5 * distance).all(axis=0)
                new_pos = new_pos[good_extra]

            tri = Delaunay(np.concatenate([pos, new_pos], axis=0))
            return new_pos, new_pos, tri

        # get the convex hull of data points from triangulation
        # (hull_pos: (n_edges, 2 endpoints, 2 coords))
        hull_pos = pos[tri.convex_hull]

        # extend the convex hull limits outwards a bit: by ``distance`` for
        # the extra points and by half of it for the mask
        channels_center = pos.mean(axis=0)
        radial_dir = hull_pos - channels_center
        unit_radial_dir = radial_dir / np.linalg.norm(
            radial_dir, axis=-1, keepdims=True
        )
        hull_extended = hull_pos + unit_radial_dir * distance
        mask_pos = hull_pos + unit_radial_dir * distance * 0.5
        hull_diff = np.diff(hull_pos, axis=1)[:, 0]
        hull_distances = np.linalg.norm(hull_diff, axis=-1)
        del channels_center

        # Construct a mask: unique points ordered by angle about their mean
        mask_pos = np.unique(mask_pos.reshape(-1, 2), axis=0)
        mask_center = np.mean(mask_pos, axis=0)
        mask_pos -= mask_center
        mask_pos = mask_pos[np.argsort(np.arctan2(mask_pos[:, 1], mask_pos[:, 0]))]
        mask_pos += mask_center

        # add points along hull edges: each edge is split into
        # n = round(0.25 * length / distance) parts (edges with n < 2 get no
        # extra points), i.e. spacing of roughly 4x the median distance
        add_points = list()
        eps = np.finfo("float").eps
        n_times_dist = np.round(0.25 * hull_distances / distance).astype("int")
        for n in range(2, n_times_dist.max() + 1):
            mask = n_times_dist == n
            mult = np.arange(1 / n, 1 - eps, 1 / n)[:, np.newaxis, np.newaxis]
            steps = hull_diff[mask][np.newaxis, ...] * mult
            add_points.append(
                (hull_extended[mask, 0][np.newaxis, ...] + steps).reshape((-1, 2))
            )

        # remove duplicates from hull_extended
        hull_extended = np.unique(hull_extended.reshape((-1, 2)), axis=0)
        new_pos = np.concatenate([hull_extended] + add_points)
    else:
        assert extrapolate == "head"
        # return points on the head circle (an ellipse of radii * 1.1 +
        # distance around origin), at least 12 points, spaced about
        # ``distance`` apart on a circle of the mean radius
        angle = np.arcsin(min(distance / np.mean(radii), 1))
        n_pnts = max(12, int(np.round(2 * np.pi / angle)))
        points_l = np.linspace(0, 2 * np.pi, n_pnts, endpoint=False)
        use_radii = radii * 1.1 + distance
        points_x = np.cos(points_l) * use_radii[0] + x
        points_y = np.sin(points_l) * use_radii[1] + y
        new_pos = np.stack([points_x, points_y], axis=1)
        if colinear or pos.shape[0] == 3:
            tri = Delaunay(np.concatenate([pos, new_pos], axis=0))
            return new_pos, mask_pos, tri
    # NOTE(review): ``tri`` is only bound above by the incremental Delaunay
    # (not colinear and >= 4 points); a 1- or 2-point input that is not
    # flagged colinear would fail here.
    tri.add_points(new_pos)
    return new_pos, mask_pos, tri


class _GridData:
    """Unstructured (x,y) data interpolator.

    This class allows optimized interpolation by computing parameters
    for a fixed set of true points, and allowing the values at those points
    to be set independently.
    """

    def __init__(self, pos, image_interp, extrapolate, origin, radii, border):
        # in principle this works in N dimensions, not just 2
        assert pos.ndim == 2 and pos.shape[1] == 2, pos.shape
        _validate_type(border, ("numeric", str), "border")

        # check that border, if string, is correct
        if isinstance(border, str):
            _check_option("border", border, ("mean",), extra="when a string")

        # Adding points outside the extremes helps the interpolators
        outer_pts, mask_pts, tri = _get_extra_points(pos, extrapolate, origin, radii)
        self.n_extra = outer_pts.shape[0]
        self.mask_pts = mask_pts
        self.border = border
        self.tri = tri
        self.interp = {
            "cubic": CloughTocher2DInterpolator,
            "nearest": NearestNDInterpolator,  # used only for anim
            "linear": LinearNDInterpolator,
        }[image_interp]

    def set_values(self, v):
        """Set the values at interpolation points."""
        # Rbf with thin-plate is what we used to use, but it's slower and
        # looks about the same:
        #
        #     zi = Rbf(x, y, v, function='multiquadric', smooth=0)(xi, yi)
        #
        # Eventually we could also do set_values with this class if we want,
        # see scipy/interpolate/rbf.py, especially the self.nodes one-liner.
        if isinstance(self.border, str):
            # we've already checked that border = 'mean'
            n_points = v.shape[0]
            v_extra = np.zeros(self.n_extra)
            indices, indptr = self.tri.vertex_neighbor_vertices
            rng = range(n_points, n_points + self.n_extra)
            used = np.zeros(len(rng), bool)
            for idx, extra_idx in enumerate(rng):
                ngb = indptr[indices[extra_idx] : indices[extra_idx + 1]]
                ngb = ngb[ngb < n_points]
                if len(ngb) > 0:
                    used[idx] = True
                    v_extra[idx] = v[ngb].mean()
            if not used.all() and used.any():
                # Eventually we might want to use the value of the nearest
                # point or something, but this case should hopefully be
                # rare so for now just use the average value of all extras
                v_extra[~used] = np.mean(v_extra[used])
        else:
            v_extra = np.full(self.n_extra, self.border, dtype=float)

        v = np.concatenate((v, v_extra))
        self.interpolator = self.interp(self.tri, v)
        return self

    def set_locations(self, Xi, Yi):
        """Set locations for easier (delayed) calling."""
        self.Xi = Xi
        self.Yi = Yi
        return self

    def __call__(self, *args):
        """Evaluate the interpolator."""
        if len(args) == 0:
            args = [self.Xi, self.Yi]
        return self.interpolator(*args)


def _topomap_plot_sensors(pos_x, pos_y, sensors, ax):
    """Plot sensors."""
    if sensors is True:
        ax.scatter(
            pos_x,
            pos_y,
            s=0.25,
            marker="o",
            edgecolor=["k"] * len(pos_x),
            facecolor="none",
        )
    else:
        ax.plot(pos_x, pos_y, sensors)


def _get_pos_outlines(info, picks, sphere, to_sphere=True):
    """Return 2D topomap positions and head outlines for ``picks``.

    The channel type is inferred from the picked channels. The sphere is
    adjusted with ``_adjust_meg_sphere``, and the returned outlines are
    built in ``"head"`` mode around the resulting clip origin.

    Returns
    -------
    pos : array
        Positions from ``_find_topomap_coords`` (overlaps ignored).
    outlines : dict
        Output of ``_make_head_outlines``.
    """
    from ..channels.layout import _find_topomap_coords

    picks = _picks_to_idx(info, picks, "all", exclude=(), allow_empty=False)
    picked_info = pick_info(_simplify_info(info), picks)
    ch_type = _get_plot_ch_type(picked_info, None)
    requested_sphere = sphere
    sphere, clip_origin = _adjust_meg_sphere(sphere, info, ch_type)
    logger.debug(
        f"Generating pos outlines with sphere {sphere} from {requested_sphere} "
        f"for {ch_type}"
    )
    pos = _find_topomap_coords(
        info, picks, ignore_overlap=True, to_sphere=to_sphere, sphere=sphere
    )
    return pos, _make_head_outlines(sphere, pos, "head", clip_origin)


@fill_doc
def plot_topomap(
    data,
    pos,
    *,
    ch_type="eeg",
    sensors=True,
    names=None,
    mask=None,
    mask_params=None,
    contours=6,
    outlines="head",
    sphere=None,
    image_interp=_INTERPOLATION_DEFAULT,
    extrapolate=_EXTRAPOLATE_DEFAULT,
    border=_BORDER_DEFAULT,
    res=64,
    size=1,
    cmap=None,
    vlim=(None, None),
    cnorm=None,
    axes=None,
    show=True,
    onselect=None,
):
    """Plot a topographic map as image.

    Parameters
    ----------
    data : array, shape (n_chan,)
        The data values to plot.
    %(pos_topomap)s
        If an :class:`~mne.Info` object it must contain only one channel type
        and exactly ``len(data)`` channels; the x/y coordinates will
        be inferred from the montage in the :class:`~mne.Info` object.
    %(ch_type_topomap)s

        .. versionadded:: 0.21
    %(sensors_topomap)s
    %(names_topomap)s
    %(mask_topomap)s
    %(mask_params_topomap)s
    %(contours_topomap)s
    %(outlines_topomap)s
    %(sphere_topomap_auto)s
    %(image_interp_topomap)s
    %(extrapolate_topomap)s

        .. versionadded:: 0.18

        .. versionchanged:: 0.21

           - The default was changed to ``'local'`` for MEG sensors.
           - ``'local'`` was changed to use a convex hull mask
           - ``'head'`` was changed to extrapolate out to the clipping circle.
    %(border_topomap)s

        .. versionadded:: 0.20
    %(res_topomap)s
    %(size_topomap)s
    %(cmap_topomap)s
    %(vlim_plot_topomap)s

        .. versionadded:: 1.2
    %(cnorm)s

        .. versionadded:: 0.24
    %(axes_plot_topomap)s

        .. versionchanged:: 1.2
           If ``axes=None``, a new :class:`~matplotlib.figure.Figure` is
           created instead of plotting into the current axes.
    %(show)s
    onselect : callable | None
        Callback invoked when the user click-drags to select a group of
        channels (implemented with a matplotlib
        :class:`~matplotlib.widgets.RectangleSelector`). ``None`` (the
        default) disables interactive channel selection.

    Returns
    -------
    im : matplotlib.image.AxesImage
        The image holding the interpolated data.
    cn : matplotlib.contour.ContourSet
        The contour lines drawn over the image.
    """
    import matplotlib.pyplot as plt
    from matplotlib.colors import Normalize

    if axes is None:
        _, axes = plt.subplots(figsize=(size, size), layout="constrained")
    info_for_sphere = pos if isinstance(pos, Info) else None
    sphere = _check_sphere(sphere, info_for_sphere)
    _validate_type(cnorm, (Normalize, None), "cnorm")
    # an explicit cnorm wins over vlim; tell the user their vlim is ignored
    if cnorm is not None and (vlim[0] is not None or vlim[1] is not None):
        warn(
            f"Provided cnorm implicitly defines vmin={cnorm.vmin} and "
            f"vmax={cnorm.vmax}; ignoring additional vlim/vmin/vmax params."
        )
    topomap_kwargs = dict(
        vmin=vlim[0],
        vmax=vlim[1],
        cmap=cmap,
        sensors=sensors,
        res=res,
        axes=axes,
        names=names,
        mask=mask,
        mask_params=mask_params,
        outlines=outlines,
        contours=contours,
        image_interp=image_interp,
        show=show,
        onselect=onselect,
        extrapolate=extrapolate,
        sphere=sphere,
        border=border,
        ch_type=ch_type,
        cnorm=cnorm,
    )
    # _plot_topomap also returns the interpolator; the public API exposes
    # only the image and the contours.
    image_and_contours = _plot_topomap(data, pos, **topomap_kwargs)[:2]
    return image_and_contours


def _setup_interp(pos, res, image_interp, extrapolate, outlines, border):
    if image_interp not in ("cubic", "linear", "nearest"):
        raise RuntimeError(
            "`image_interp` must be `cubic`, `linear` or `nearest`, got "
            f"{image_interp}. Previously, `image_interp` controlled "
            "the matplotlib image interpolation after an original cubic "
            "interpolation but this was changed to control the first "
            "interpolation instead for simplicity. To change the "
            "matplotlib image interpolation, use "
            "`im.set_interpolation(...)"
        )
    logger.debug(
        f"Interpolation mode {image_interp}, "
        f"extrapolation mode {extrapolate} to {border}"
    )
    xlim = (
        np.inf,
        -np.inf,
    )
    ylim = (
        np.inf,
        -np.inf,
    )
    mask_ = np.c_[outlines["mask_pos"]]
    clip_radius = outlines["clip_radius"]
    clip_origin = outlines.get("clip_origin", (0.0, 0.0))
    xmin, xmax = (
        np.min(np.r_[xlim[0], mask_[:, 0], clip_origin[0] - clip_radius[0]]),
        np.max(np.r_[xlim[1], mask_[:, 0], clip_origin[0] + clip_radius[0]]),
    )
    ymin, ymax = (
        np.min(np.r_[ylim[0], mask_[:, 1], clip_origin[1] - clip_radius[1]]),
        np.max(np.r_[ylim[1], mask_[:, 1], clip_origin[1] + clip_radius[1]]),
    )
    xi = np.linspace(xmin, xmax, res)
    yi = np.linspace(ymin, ymax, res)
    Xi, Yi = np.meshgrid(xi, yi)
    interp = _GridData(pos, image_interp, extrapolate, clip_origin, clip_radius, border)
    extent = (xmin, xmax, ymin, ymax)
    return extent, Xi, Yi, interp


_VORONOI_CIRCLE_RES = 100


def _voronoi_topomap(data, pos, outlines, ax, cmap, norm, extent, res):
    """Make a Voronoi diagram on a topomap."""
    # we need an image axis object so first empty image to plot over
    im = ax.imshow(
        np.zeros((res, res)) * np.nan,
        cmap=cmap,
        origin="lower",
        aspect="equal",
        extent=extent,
        norm=norm,
    )
    rx, ry = outlines["clip_radius"]
    cx, cy = outlines.get("clip_origin", (0.0, 0.0))
    # add points on the circle to make boundaries, expand out to
    # ensure regions extend to the edge of the topomap
    vor = Voronoi(
        np.concatenate(
            [
                pos,
                [
                    (
                        rx * 1.5 * np.cos(2 * np.pi / _VORONOI_CIRCLE_RES * t),
                        ry * 1.5 * np.sin(2 * np.pi / _VORONOI_CIRCLE_RES * t),
                    )
                    for t in range(_VORONOI_CIRCLE_RES)
                ],
            ]
        )
    )
    for point_idx, region_idx in enumerate(vor.point_region[:-_VORONOI_CIRCLE_RES]):
        if -1 in vor.regions[region_idx]:
            continue
        polygon = list()
        for i in vor.regions[region_idx]:
            x, y = vor.vertices[i]
            if (x - cx) ** 2 / rx**2 + (y - cy) ** 2 / ry**2 < 1:
                polygon.append((x, y))
            else:
                x *= rx / np.linalg.norm(vor.vertices[i])
                y *= ry / np.linalg.norm(vor.vertices[i])
                polygon.append((x, y))
        ax.fill(*zip(*polygon), color=cmap(norm(data[point_idx])))
    return im


def _get_patch(outlines, extrapolate, interp, ax):
    """Return the patch used to clip topomap artists, or None.

    A user-provided ``outlines['patch']`` (a patch or a callable returning
    one) is added to ``ax`` and set as the axes clip path. If ``outlines``
    has any ``'head*'`` key, the returned patch is instead a polygon over
    ``interp.mask_pts`` (``extrapolate='local'``) or an ellipse of radii
    ``clip_radius`` around ``clip_origin``.
    """
    from matplotlib import patches

    radius = outlines["clip_radius"]
    origin = outlines.get("clip_origin", (0.0, 0.0))
    has_head_outlines = any(key.startswith("head") for key in outlines)

    clip_patch = None
    if "patch" in outlines:
        clip_patch = outlines["patch"]
        if callable(clip_patch):
            clip_patch = clip_patch()
        clip_patch.set_clip_on(False)
        ax.add_patch(clip_patch)
        ax.set_transform(ax.transAxes)
        ax.set_clip_path(clip_patch)

    if not has_head_outlines:
        return clip_patch
    if extrapolate == "local":
        return patches.Polygon(interp.mask_pts, clip_on=True, transform=ax.transData)
    return patches.Ellipse(
        origin,
        2 * radius[0],
        2 * radius[1],
        clip_on=True,
        transform=ax.transData,
    )


def _plot_topomap(
    data,
    pos,
    axes,
    *,
    ch_type="eeg",
    sensors=True,
    names=None,
    mask=None,
    mask_params=None,
    contours=6,
    outlines="head",
    sphere=None,
    image_interp=_INTERPOLATION_DEFAULT,
    extrapolate=_EXTRAPOLATE_DEFAULT,
    border=_BORDER_DEFAULT,
    res=64,
    cmap=None,
    vmin=None,
    vmax=None,
    cnorm=None,
    show=True,
    onselect=None,
):
    """Draw an interpolated topomap of ``data`` into ``axes``.

    This is the worker behind :func:`plot_topomap`; see that function for
    the meaning of the shared parameters.

    Parameters
    ----------
    data : array-like
        One value per sensor. It must be 1D after any gradiometer merging.
    pos : array | Info
        Either positions with 2 columns (x, y) or 4 columns (x, y, width,
        height; only the first two are used), or an Info object. With an
        Info, only data channels are used. It must contain a single channel
        type and exactly ``len(data)`` channels; planar/grad pairs are
        merged. The channel type found in the Info replaces ``ch_type``.
    axes : matplotlib Axes
        Target axes.

    Returns
    -------
    im : image
        The image returned by ``axes.imshow``. For ``image_interp='nearest'``
        it is the placeholder image returned by ``_voronoi_topomap``.
    cont : ContourSet | None
        The contour lines. None if ``contours == 0`` or the interpolated
        grid is constant/NaN everywhere.
    interp : _GridData
        The interpolator, with values and grid locations already set.

    Raises
    ------
    ValueError
        For multiple channel types or a channel-count mismatch with an Info
        ``pos``, for non-1D ``data``, for a malformed ``pos`` array, or when
        ``data`` and ``pos`` lengths differ.
    """
    from matplotlib.colors import Normalize
    from matplotlib.widgets import RectangleSelector

    from ..channels.layout import (
        _find_topomap_coords,
        _merge_ch_data,
        _pair_grad_sensors,
    )

    data = np.asarray(data)
    logger.debug(f"Plotting topomap for {ch_type} data shape {data.shape}")

    if isinstance(pos, Info):  # infer pos from Info object
        picks = _pick_data_channels(pos, exclude=())  # pick only data channels
        pos = pick_info(pos, picks)

        # check if there is only 1 channel type, and n_chans matches the data
        ch_type = pos.get_channel_types(picks=None, unique=True)
        info_help = (
            "Pick the Info object "
            "(e.g., using mne.pick_info and mne.channel_indices_by_type)."
        )
        if len(ch_type) > 1:
            raise ValueError(f"Multiple channel types in Info object. {info_help}")
        elif len(pos["chs"]) != data.shape[0]:
            raise ValueError(
                f"Number of channels in the Info object ({len(pos['chs'])}) and the "
                f"data array ({data.shape[0]}) do not match. {info_help}"
            )
        else:
            ch_type = ch_type.pop()

        if any(type_ in ch_type for type_ in ("planar", "grad")):
            # deal with grad pairs
            # one position per pair (every other pick); data for each pair is
            # combined by _merge_ch_data and flattened back to 1D
            picks = _pair_grad_sensors(pos, topomap_coords=False)
            pos = _find_topomap_coords(pos, picks=picks[::2], sphere=sphere)
            data, _ = _merge_ch_data(data[picks], ch_type, [])
            data = data.reshape(-1)
        else:
            picks = list(range(data.shape[0]))
            pos = _find_topomap_coords(pos, picks=picks, sphere=sphere)

    extrapolate = _check_extrapolate(extrapolate, ch_type)
    if data.ndim > 1:
        raise ValueError(
            f"Data needs to be array of shape (n_sensors,); got shape {data.shape}."
        )

    # Give a helpful error message for common mistakes regarding the position
    # matrix.
    pos_help = (
        "Electrode positions should be specified as a 2D array with "
        "shape (n_channels, 2). Each row in this matrix contains the "
        "(x, y) position of an electrode."
    )
    if pos.ndim != 2:
        error = (
            f"{pos.ndim}D array supplied as electrode positions, where a 2D array was "
            "expected"
        )
        raise ValueError(error + " " + pos_help)
    elif pos.shape[1] == 3:
        error = (
            "The supplied electrode positions matrix contains 3 columns. "
            "Are you trying to specify XYZ coordinates? Perhaps the "
            "mne.channels.create_eeg_layout function is useful for you."
        )
        raise ValueError(error + " " + pos_help)
    # No error is raised in case of pos.shape[1] == 4. In this case, it is
    # assumed the position matrix contains both (x, y) and (width, height)
    # values, such as Layout.pos.
    elif pos.shape[1] == 1 or pos.shape[1] > 4:
        raise ValueError(pos_help)
    pos = pos[:, :2]

    if len(data) != len(pos):
        raise ValueError(
            "Data and pos need to be of same length. Got data of "
            f"length {len(data)}, pos of length {len(pos)}"
        )

    # all-non-negative data gets a sequential default colormap ("Reds");
    # the same flag is passed to _setup_vmin_vmax
    norm = min(data) >= 0
    vmin, vmax = _setup_vmin_vmax(data, vmin, vmax, norm)
    if cmap is None:
        cmap = "Reds" if norm else "RdBu_r"
    cmap = _get_cmap(cmap)

    outlines = _make_head_outlines(sphere, pos, outlines, (0.0, 0.0))
    assert isinstance(outlines, dict)

    _prepare_topomap(pos, axes)

    mask_params = _handle_default("mask_params", mask_params)

    # find mask limits and setup interpolation
    extent, Xi, Yi, interp = _setup_interp(
        pos, res, image_interp, extrapolate, outlines, border
    )
    interp.set_values(data)
    Zi = interp.set_locations(Xi, Yi)()

    # plot outline
    patch_ = _get_patch(outlines, extrapolate, interp, axes)

    # get colormap normalization
    if cnorm is None:
        cnorm = Normalize(vmin=vmin, vmax=vmax)

    # plot interpolated map
    if image_interp == "nearest":  # plot over with Voronoi, more accurate
        im = _voronoi_topomap(
            data,
            pos=pos,
            outlines=outlines,
            ax=axes,
            cmap=cmap,
            norm=cnorm,
            extent=extent,
            res=res,
        )
    else:
        im = axes.imshow(
            Zi,
            cmap=cmap,
            origin="lower",
            aspect="equal",
            extent=extent,
            interpolation="bilinear",
            norm=cnorm,
        )

    # gh-1432 had a workaround for no contours here, but we'll remove it
    # because mpl has probably fixed it
    linewidth = mask_params["markeredgewidth"]
    # cont ends up as either a ContourSet or None: explicit contour levels
    # (array/list) are always drawn, otherwise drawing is skipped for
    # contours == 0 or a grid that is constant/NaN everywhere
    cont = True
    if isinstance(contours, np.ndarray | list):
        pass
    elif contours == 0 or ((Zi == Zi[0, 0]) | np.isnan(Zi)).all():
        cont = None  # can't make contours for constant-valued functions
    if cont:
        with warnings.catch_warnings(record=True):
            warnings.simplefilter("ignore")
            cont = axes.contour(
                Xi, Yi, Zi, contours, colors="k", linewidths=linewidth / 2.0
            )

    # clip image and contours to the head/hull patch
    if patch_ is not None:
        im.set_clip_path(patch_)
        if cont is not None:
            for col in _cont_collections(cont):
                col.set_clip_path(patch_)

    pos_x, pos_y = pos.T
    mask = mask.astype(bool, copy=False) if mask is not None else None
    # masked sensors are drawn with mask_params; unmasked ones with the
    # regular sensor style (only when sensors is truthy)
    if sensors is not False and mask is None:
        _topomap_plot_sensors(pos_x, pos_y, sensors=sensors, ax=axes)
    elif sensors and mask is not None:
        idx = np.where(mask)[0]
        axes.plot(pos_x[idx], pos_y[idx], **mask_params)
        idx = np.where(~mask)[0]
        _topomap_plot_sensors(pos_x[idx], pos_y[idx], sensors=sensors, ax=axes)
    elif not sensors and mask is not None:
        idx = np.where(mask)[0]
        axes.plot(pos_x[idx], pos_y[idx], **mask_params)

    if isinstance(outlines, dict):
        _draw_outlines(axes, outlines)

    # channel names are only drawn when sensors are shown
    if names is not None and sensors:
        for _pos, _name in zip(pos, names):
            axes.text(
                _pos[0],
                _pos[1],
                _name,
                horizontalalignment="center",
                verticalalignment="center",
                size="x-small",
            )

    if onselect is not None:
        # keep a reference on the axes so the selector is not garbage
        # collected, then pin the limits to the current data limits
        lim = axes.dataLim
        x0, y0, width, height = lim.x0, lim.y0, lim.width, lim.height
        axes.RS = RectangleSelector(axes, onselect=onselect)
        axes.set(xlim=[x0, x0 + width], ylim=[y0, y0 + height])
    plt_show(show)
    return im, cont, interp


def _plot_ica_topomap(
    ica,
    idx=0,
    ch_type=None,
    res=64,
    vmin=None,
    vmax=None,
    cmap="RdBu_r",
    colorbar=False,
    title=None,
    show=True,
    outlines="head",
    contours=6,
    image_interp=_INTERPOLATION_DEFAULT,
    axes=None,
    sensors=True,
    allow_ref_meg=False,
    extrapolate=_EXTRAPOLATE_DEFAULT,
    sphere=None,
    border=_BORDER_DEFAULT,
):
    """Plot single ica map to axes.

    Draws component ``idx`` of ``ica`` into ``axes``. The axes are titled
    with the component name, plus the channel type when the ICA covers
    several channel types. An ``"AU"`` colorbar is added if ``colorbar`` is
    True, and the axes frame is hidden. Nothing is drawn (and None is
    returned) when the resolved channel type is ``'ref_meg'``.

    NOTE(review): the ``allow_ref_meg`` argument is not consulted;
    ``ica.allow_ref_meg`` is passed to ``_get_plot_ch_type`` instead.

    Raises
    ------
    RuntimeError
        If ``ica.info`` is None.
    ValueError
        If ``axes`` is not a matplotlib Axes.
    """
    from matplotlib.axes import Axes

    from ..channels.layout import _merge_ch_data

    if ica.info is None:
        raise RuntimeError(
            "The ICA's measurement info is missing. Please "
            "fit the ICA or add the corresponding info object."
        )
    sphere = _check_sphere(sphere, ica.info)
    if not isinstance(axes, Axes):
        raise ValueError(
            f"axis has to be an instance of matplotlib Axes, got {type(axes)} instead."
        )
    ch_type = _get_plot_ch_type(ica, ch_type, allow_ref_meg=ica.allow_ref_meg)
    if ch_type == "ref_meg":
        logger.info("Cannot produce topographies for MEG reference channels.")
        return None

    component = ica.get_components()[:, idx]
    prepared = _prepare_topomap_plot(ica, ch_type, sphere=sphere)
    data_picks, pos, merge_channels, names, _, sphere, clip_origin = prepared
    component = component[data_picks]
    outlines = _make_head_outlines(sphere, pos, outlines, clip_origin)
    if merge_channels:
        component, names = _merge_ch_data(component, ch_type, names)

    label = ica._ica_names[idx]
    if len(set(ica.get_channel_types())) > 1:
        label = label + f" ({ch_type})"
    axes.set_title(label, fontsize=12)

    vlim = _setup_vmin_vmax(component, vmin, vmax)
    topomap_kwargs = dict(
        vlim=vlim,
        res=res,
        axes=axes,
        cmap=cmap,
        outlines=outlines,
        contours=contours,
        sensors=sensors,
        image_interp=image_interp,
        show=show,
        extrapolate=extrapolate,
        sphere=sphere,
        border=border,
        ch_type=ch_type,
    )
    image = plot_topomap(component.ravel(), pos, **topomap_kwargs)[0]
    if colorbar:
        cbar, _ = _add_colorbar(
            axes,
            image,
            cmap,
            title="AU",
            format_="%3.2f",
            kind="ica_topomap",
            ch_type=ch_type,
        )
        cbar.ax.tick_params(labelsize=12)
        cbar.set_ticks(vlim)
    _hide_frame(axes)


@verbose
def plot_ica_components(
    ica,
    picks=None,
    ch_type=None,
    *,
    inst=None,
    plot_std=True,
    reject="auto",
    sensors=True,
    show_names=False,
    contours=6,
    outlines="head",
    sphere=None,
    image_interp=_INTERPOLATION_DEFAULT,
    extrapolate=_EXTRAPOLATE_DEFAULT,
    border=_BORDER_DEFAULT,
    res=64,
    size=1,
    cmap="RdBu_r",
    vlim=(None, None),
    cnorm=None,
    colorbar=False,
    cbar_fmt="%3.2f",
    axes=None,
    title=None,
    nrows="auto",
    ncols="auto",
    show=True,
    image_args=None,
    psd_args=None,
    verbose=None,
):
    """Project mixing matrix on interpolated sensor topography.

    Parameters
    ----------
    ica : instance of mne.preprocessing.ICA
        The ICA solution.
    %(picks_ica)s
    %(ch_type_topomap)s
    inst : Raw | Epochs | None
        To be able to see component properties after clicking on component
        topomap you need to pass relevant data - instances of Raw or Epochs
        (for example the data that ICA was trained on). This takes effect
        only when running matplotlib in interactive mode.
    plot_std : bool | float
        Whether to plot standard deviation in ERP/ERF and spectrum plots.
        Defaults to True, which plots one standard deviation above/below.
        If set to float allows to control how many standard deviations are
        plotted. For example 2.5 will plot 2.5 standard deviation above/below.
    reject : ``'auto'`` | dict | None
        Allows to specify rejection parameters used to drop epochs
        (or segments if continuous signal is passed as inst).
        If None, no rejection is applied. The default is 'auto',
        which applies the rejection parameters used when fitting
        the ICA object.
    %(sensors_topomap)s
    %(show_names_topomap)s
    %(contours_topomap)s
    %(outlines_topomap)s
    %(sphere_topomap_auto)s
    %(image_interp_topomap)s
    %(extrapolate_topomap)s

        .. versionadded:: 1.3
    %(border_topomap)s

        .. versionadded:: 1.3
    %(res_topomap)s
    %(size_topomap)s

        .. versionadded:: 1.3
    %(cmap_topomap)s
    %(vlim_plot_topomap)s

        .. versionadded:: 1.3
    %(cnorm)s

        .. versionadded:: 1.3
    %(colorbar_topomap)s
    %(cbar_fmt_topomap)s
    axes : Axes | array of Axes | None
        The subplot(s) to plot to. Either a single Axes or an iterable of Axes
        if more than one subplot is needed. The number of subplots must match
        the number of selected components. If None, new figures will be created
        with the number of subplots per figure controlled by ``nrows`` and
        ``ncols``.
    title : str | None
        The title of the generated figure. If ``None`` (default) and
        ``axes=None``, a default title of "ICA Components" will be used.
    %(nrows_ncols_ica_components)s

        .. versionadded:: 1.3
    %(show)s
    image_args : dict | None
        Dictionary of arguments to pass to :func:`~mne.viz.plot_epochs_image`
        in interactive mode. Ignored if ``inst`` is not supplied. If ``None``,
        nothing is passed. Defaults to ``None``.
    psd_args : dict | None
        Dictionary of arguments to pass to :meth:`~mne.Epochs.compute_psd` in
        interactive mode. Ignored if ``inst`` is not supplied. If ``None``,
        nothing is passed. Defaults to ``None``.
    %(verbose)s

    Returns
    -------
    fig : instance of matplotlib.figure.Figure | list of matplotlib.figure.Figure
        The figure object(s).

    Notes
    -----
    When run in interactive mode, ``plot_ica_components`` allows to reject
    components by clicking on their title label. The state of each component
    is indicated by its label color (gray: rejected; black: retained). It is
    also possible to open component properties by clicking on the component
    topomap (this option is only available when the ``inst`` argument is
    supplied).
    """  # noqa E501
    from matplotlib.pyplot import Axes

    from ..channels.layout import _merge_ch_data
    from ..epochs import BaseEpochs
    from ..io import BaseRaw

    if ica.info is None:
        raise RuntimeError(
            "The ICA's measurement info is missing. Please "
            "fit the ICA or add the corresponding info object."
        )

    # for backward compat, nrow='auto' ncol='auto' should yield 4 rows 5 cols
    # and create multiple figures if more than 20 components requested
    if nrows == "auto" and ncols == "auto":
        ncols = 5
        max_subplots = 20
    elif nrows == "auto" or ncols == "auto":
        # user provided incomplete row/col spec; put all in one figure
        max_subplots = ica.n_components_
    else:
        max_subplots = nrows * ncols

    # handle ch_type=None
    ch_type = _get_plot_ch_type(ica, ch_type)

    # Sensor selection, positions, sphere and outlines depend only on the ICA
    # and ch_type, not on which components land in which figure, so prepare
    # them once. Re-running _prepare_topomap_plot per figure would feed the
    # sphere it returned (already passed through _adjust_meg_sphere) back in
    # as its input, so figures after the first could get a different sphere.
    (
        data_picks,
        pos,
        merge_channels,
        names,
        ch_type,
        sphere,
        clip_origin,
    ) = _prepare_topomap_plot(ica, ch_type, sphere=sphere)
    disp_names = _prepare_sensor_names(names, show_names)
    outlines = _make_head_outlines(sphere, pos, outlines, clip_origin)
    if title is None:
        title = "ICA components"
    show_ch_type = len(set(ica.get_channel_types())) > 1
    # Component name -> index. Topomap axes are labelled with the component
    # name below, so clicks can be resolved without parsing the name (the old
    # label[-3:] parsing broke for 1000+ components).
    ic_by_name = {name: ii for ii, name in enumerate(ica._ica_names)}

    figs = []
    if picks is None:
        cut_points = range(max_subplots, ica.n_components_, max_subplots)
        pick_groups = np.split(range(ica.n_components_), cut_points)
    else:
        pick_groups = [_picks_to_idx(ica.n_components_, picks, picks_on="components")]

    axes = axes.flatten() if isinstance(axes, np.ndarray) else axes
    for k, picks in enumerate(pick_groups):
        try:  # either an iterable, 1D numpy array or others
            _axes = axes[k * max_subplots : (k + 1) * max_subplots]
        except TypeError:  # None or Axes
            _axes = axes

        cmap = _setup_cmap(cmap, n_axes=len(picks))

        data = np.dot(
            ica.mixing_matrix_[:, picks].T, ica.pca_components_[: ica.n_components_]
        )
        data = np.atleast_2d(data)
        data = data[:, data_picks]

        if _axes is None:
            fig, _axes, _, _ = _prepare_trellis(len(data), ncols=ncols, nrows=nrows)
            fig.suptitle(title)
        else:
            _axes = [_axes] if isinstance(_axes, Axes) else _axes
            fig = _axes[0].get_figure()

        subplot_titles = list()
        subplot_ics = list()  # component index for each entry of subplot_titles
        for ii, data_, ax in zip(picks, data, _axes):
            kwargs = dict(color="gray") if ii in ica.exclude else dict()
            comp_title = ica._ica_names[ii]
            if show_ch_type:
                comp_title += f" ({ch_type})"
            subplot_titles.append(ax.set_title(comp_title, fontsize=12, **kwargs))
            subplot_ics.append(int(ii))
            if merge_channels:
                data_, _ = _merge_ch_data(data_, ch_type, copy.copy(names))
            # ↓↓↓ NOTE: we intentionally use the default norm=False here, so that
            # ↓↓↓ we get vlims that are symmetric-about-zero, even if the data for
            # ↓↓↓ a given component happens to be one-sided.
            _vlim = _setup_vmin_vmax(data_, *vlim)
            im = plot_topomap(
                data_.flatten(),
                pos,
                ch_type=ch_type,
                sensors=sensors,
                names=disp_names,
                contours=contours,
                outlines=outlines,
                sphere=sphere,
                image_interp=image_interp,
                extrapolate=extrapolate,
                border=border,
                res=res,
                size=size,
                cmap=cmap[0],
                vlim=_vlim,
                cnorm=cnorm,
                axes=ax,
                show=False,
            )[0]

            im.axes.set_label(ica._ica_names[ii])
            if colorbar:
                cbar, cax = _add_colorbar(
                    ax,
                    im,
                    cmap,
                    title="AU",
                    format_=cbar_fmt,
                    kind="ica_comp_topomap",
                    ch_type=ch_type,
                )
                cbar.ax.tick_params(labelsize=12)
                cbar.set_ticks(_vlim)
            _hide_frame(ax)
        fig.canvas.draw()

        # add title selection interactivity
        def onclick_title(
            event, ica=ica, titles=subplot_titles, ics=subplot_ics, fig=fig
        ):
            # find the pressed title (if any) and the IC it belongs to
            for ic, title_text in zip(ics, titles):
                if title_text.contains(event)[0]:
                    break
            else:
                return
            # add or remove IC from exclude depending on current state
            if ic in ica.exclude:
                ica.exclude.remove(ic)
                title_text.set_color("k")
            else:
                ica.exclude.append(ic)
                title_text.set_color("gray")
            fig.canvas.draw()

        fig.canvas.mpl_connect("button_press_event", onclick_title)

        # add plot_properties interactivity only if inst was passed
        if isinstance(inst, BaseRaw | BaseEpochs):
            topomap_args = dict(
                sensors=sensors,
                contours=contours,
                outlines=outlines,
                sphere=sphere,
                image_interp=image_interp,
                extrapolate=extrapolate,
                border=border,
                res=res,
                cmap=cmap[0],
                vmin=vlim[0],
                vmax=vlim[1],
            )

            def onclick_topo(event, ica=ica, inst=inst, topomap_args=topomap_args):
                # check which component to plot
                if event.inaxes is None:
                    return
                ic = ic_by_name.get(event.inaxes.get_label())
                if ic is None:  # not a component topomap (e.g. a colorbar)
                    return
                ica.plot_properties(
                    inst,
                    picks=ic,
                    show=True,
                    plot_std=plot_std,
                    topomap_args=topomap_args,
                    image_args=image_args,
                    psd_args=psd_args,
                    reject=reject,
                )

            fig.canvas.mpl_connect("button_press_event", onclick_topo)
        figs.append(fig)

    plt_show(show)
    return figs[0] if len(figs) == 1 else figs


@fill_doc
def plot_tfr_topomap(
    tfr,
    tmin=None,
    tmax=None,
    fmin=0.0,
    fmax=np.inf,
    *,
    ch_type=None,
    baseline=None,
    mode="mean",
    sensors=True,
    show_names=False,
    mask=None,
    mask_params=None,
    contours=6,
    outlines="head",
    sphere=None,
    image_interp=_INTERPOLATION_DEFAULT,
    extrapolate=_EXTRAPOLATE_DEFAULT,
    border=_BORDER_DEFAULT,
    res=64,
    size=2,
    cmap=None,
    vlim=(None, None),
    cnorm=None,
    colorbar=True,
    cbar_fmt="%1.1e",
    units=None,
    axes=None,
    show=True,
):
    """Plot topographic maps of specific time-frequency intervals of TFR data.

    Parameters
    ----------
    tfr : AverageTFR
        The AverageTFR object.
    %(tmin_tmax_psd)s
    %(fmin_fmax_psd)s
    %(ch_type_topomap_psd)s
    baseline : None | tuple or list of length 2
        The time interval to apply rescaling / baseline correction. If None do
        not apply it. If baseline is (a, b) the interval is between "a (s)" and
        "b (s)". If a is None the beginning of the data is used and if b is
        None then b is set to the end of the interval. If baseline is equal to
        (None, None) the whole time interval is used.
    mode : 'mean' | 'ratio' | 'logratio' | 'percent' | 'zscore' | 'zlogratio' | None
        Perform baseline correction by

          - subtracting the mean baseline power ('mean')
          - dividing by the mean baseline power ('ratio')
          - dividing by the mean baseline power and taking the log ('logratio')
          - subtracting the mean baseline power followed by dividing by the
            mean baseline power ('percent')
          - subtracting the mean baseline power and dividing by the standard
            deviation of the baseline power ('zscore')
          - dividing by the mean baseline power, taking the log, and dividing
            by the standard deviation of the baseline power ('zlogratio')

        If None no baseline correction is applied.
    %(sensors_topomap)s
    %(show_names_topomap)s
    %(mask_evoked_topomap)s
    %(mask_params_topomap)s
    %(contours_topomap)s
    %(outlines_topomap)s
    %(sphere_topomap_auto)s
    %(image_interp_topomap)s
    %(extrapolate_topomap)s

        .. versionchanged:: 0.21

           - The default was changed to ``'local'`` for MEG sensors.
           - ``'local'`` was changed to use a convex hull mask
           - ``'head'`` was changed to extrapolate out to the clipping circle.
    %(border_topomap)s

        .. versionadded:: 0.20
    %(res_topomap)s
    %(size_topomap)s
    %(cmap_topomap)s
    %(vlim_plot_topomap)s

        .. versionadded:: 1.2
    %(cnorm)s

        .. versionadded:: 1.2
    %(colorbar_topomap)s
    %(cbar_fmt_topomap)s
    %(units_topomap)s
    %(axes_plot_topomap)s
    %(show)s

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure containing the topography.
    """  # noqa: E501
    import matplotlib.pyplot as plt

    from ..channels.layout import _merge_ch_data

    ch_type = _get_plot_ch_type(tfr, ch_type)

    picks, pos, merge_channels, names, _, sphere, clip_origin = _prepare_topomap_plot(
        tfr, ch_type, sphere=sphere
    )
    outlines = _make_head_outlines(sphere, pos, outlines, clip_origin)
    data = tfr.data[picks]

    # merging grads before rescaling makes ERDs visible
    if merge_channels:
        data, names = _merge_ch_data(data, ch_type, names, method="mean")

    data = rescale(data, tfr.times, baseline, mode, copy=True)

    # handle unaggregated multitaper (complex or phase multitaper data)
    if tfr.weights is not None:  # assumes a taper dimension
        logger.info("Aggregating multitaper estimates before plotting...")
        weights = tfr.weights[np.newaxis, :, :, np.newaxis]  # add channel & time dims
        data = weights * data
        if np.iscomplexobj(data):  # complex coefficients → power
            data *= data.conj()
            data = data.real.sum(axis=1)
            data *= 2 / (weights * weights.conj()).real.sum(axis=1)
        else:  # tapered phase data → weighted phase data
            data = data.mean(axis=1)
    # handle remaining complex amplitude → real power. Use |x|**2, which
    # matches the multitaper branch above; taking the square root would yield
    # amplitude rather than power.
    if np.iscomplexobj(data):
        data = (data * data.conj()).real

    # crop time (only bound the side(s) the user actually constrained)
    itmin, itmax = None, None
    idx = np.where(_time_mask(tfr.times, tmin, tmax))[0]
    if tmin is not None:
        itmin = idx[0]
    if tmax is not None:
        itmax = idx[-1] + 1
    # crop freqs
    idx = np.where(_time_mask(tfr.freqs, fmin, fmax))[0]
    ifmin = idx[0]
    ifmax = idx[-1] + 1

    # average over the selected frequency x time window -> one value per channel
    data = data[:, ifmin:ifmax, itmin:itmax]
    data = data.mean(axis=(1, 2))[:, np.newaxis]
    norm = False if np.min(data) < 0 else True
    vlim = _setup_vmin_vmax(data, *vlim, norm)
    cmap = _setup_cmap(cmap, norm=norm)

    axes = (
        plt.subplots(figsize=(size, size), layout="constrained")[1]
        if axes is None
        else axes
    )
    fig = axes.figure

    _hide_frame(axes)

    # an integer ``contours`` is turned into explicit levels; the matching
    # locator is reused for the colorbar ticks below
    locator = None
    if not isinstance(contours, list | np.ndarray):
        locator, contours = _set_contour_locator(*vlim, contours)

    fig_wrapper = list()
    selection_callback = partial(
        _onselect,
        tfr=tfr,
        pos=pos,
        ch_type=ch_type,
        itmin=itmin,
        itmax=itmax,
        ifmin=ifmin,
        ifmax=ifmax,
        cmap=cmap[0],
        fig=fig_wrapper,
    )

    names = _prepare_sensor_names(names, show_names)

    im, _ = plot_topomap(
        data[:, 0],
        pos,
        ch_type=ch_type,
        sensors=sensors,
        names=names,
        mask=mask,
        mask_params=mask_params,
        contours=contours,
        outlines=outlines,
        sphere=sphere,
        image_interp=image_interp,
        extrapolate=extrapolate,
        border=border,
        res=res,
        size=size,
        cmap=cmap[0],
        vlim=vlim,
        cnorm=cnorm,
        axes=axes,
        show=False,
        onselect=selection_callback,
    )

    if colorbar:
        from matplotlib import ticker

        units = _handle_default("units", units)["misc"]
        cbar, cax = _add_colorbar(
            axes,
            im,
            cmap,
            title=units,
            format_=cbar_fmt,
            kind="tfr_topomap",
            ch_type=ch_type,
        )
        if locator is None:
            locator = ticker.MaxNLocator(nbins=5)
        cbar.locator = locator
        cbar.update_ticks()
        cbar.ax.tick_params(labelsize=12)

    plt_show(show)
    return fig


@fill_doc
def plot_evoked_topomap(
    evoked,
    times="auto",
    *,
    average=None,
    ch_type=None,
    scalings=None,
    proj=False,
    sensors=True,
    show_names=False,
    mask=None,
    mask_params=None,
    contours=6,
    outlines="head",
    sphere=None,
    image_interp=_INTERPOLATION_DEFAULT,
    extrapolate=_EXTRAPOLATE_DEFAULT,
    border=_BORDER_DEFAULT,
    res=64,
    size=1,
    cmap=None,
    vlim=(None, None),
    cnorm=None,
    colorbar=True,
    cbar_fmt="%3.1f",
    units=None,
    axes=None,
    time_unit="s",
    time_format=None,
    nrows=1,
    ncols="auto",
    show=True,
):
    """Plot topographic maps of specific time points of evoked data.

    Parameters
    ----------
    evoked : Evoked
        The Evoked object.
    times : float | array of float | "auto" | "peaks" | "interactive"
        The time point(s) to plot. If "auto", the number of ``axes`` determines
        the amount of time point(s). If ``axes`` is also None, at most 10
        topographies will be shown with a regular time spacing between the
        first and last time instant. If "peaks", finds time points
        automatically by checking for local maxima in global field power. If
        "interactive", the time can be set interactively at run-time by using a
        slider.
    %(average_plot_evoked_topomap)s
    %(ch_type_topomap)s
    %(scalings_topomap)s
    %(proj_plot)s
    %(sensors_topomap)s
    %(show_names_topomap)s
    %(mask_evoked_topomap)s
    %(mask_params_topomap)s
    %(contours_topomap)s
    %(outlines_topomap)s
    %(sphere_topomap_auto)s
    %(image_interp_topomap)s
    %(extrapolate_topomap)s

        .. versionadded:: 0.18

        .. versionchanged:: 0.21

           - The default was changed to ``'local'`` for MEG sensors.
           - ``'local'`` was changed to use a convex hull mask
           - ``'head'`` was changed to extrapolate out to the clipping circle.
    %(border_topomap)s

        .. versionadded:: 0.20
    %(res_topomap)s
    %(size_topomap)s
    %(cmap_topomap)s
    %(vlim_plot_topomap_psd)s

        .. versionadded:: 1.2
    %(cnorm)s

        .. versionadded:: 1.2
    %(colorbar_topomap)s
    %(cbar_fmt_topomap)s
    %(units_topomap_evoked)s
    %(axes_evoked_plot_topomap)s
    time_unit : str
        The units for the time axis, can be "ms" or "s" (default).

        .. versionadded:: 0.16
    time_format : str | None
        String format for topomap values. Defaults (None) to "%%01d ms" if
        ``time_unit='ms'``, "%%0.3f s" if ``time_unit='s'``, and
        "%%g" otherwise. Can be an empty string to omit the time label.
    %(nrows_ncols_topomap)s Ignored when times == 'interactive'.

        .. versionadded:: 0.20
    %(show)s

    Returns
    -------
    fig : instance of matplotlib.figure.Figure
       The figure.

    Notes
    -----
    When existing ``axes`` are provided and ``colorbar=True``, note that the
    colorbar scale will only accurately reflect topomaps that are generated in
    the same call as the colorbar. Note also that the colorbar will not be
    resized automatically when ``axes`` are provided; use Matplotlib's
    :meth:`axes.set_position() <matplotlib.axes.Axes.set_position>` method or
    :ref:`gridspec <matplotlib:arranging_axes>` interface to adjust the colorbar
    size yourself.

    The defaults for ``contours`` and ``vlim`` are handled as follows:

    * When neither ``vlim`` nor a list of ``contours`` is passed, MNE sets
      ``vlim`` at ± the maximum absolute value of the data and then chooses
      contours within those bounds.

    * When ``vlim`` but not a list of ``contours`` is passed, MNE chooses
      contours to be within the ``vlim``.

    * When a list of ``contours`` but not ``vlim`` is passed, MNE chooses
      ``vlim`` to encompass the ``contours`` and the maximum absolute value of the
      data.

    * When both a list of ``contours`` and ``vlim`` are passed, MNE uses them
      as-is.

    When ``time=="interactive"``, the figure will publish and subscribe to the
    following UI events:

    * :class:`~mne.viz.ui_events.TimeChange` whenever a new time is selected.
    """
    import matplotlib.pyplot as plt
    from matplotlib.gridspec import GridSpec
    from matplotlib.widgets import Slider

    from ..channels.layout import _merge_ch_data
    from ..evoked import Evoked

    _validate_type(evoked, Evoked, "evoked")
    _validate_type(colorbar, bool, "colorbar")
    evoked = evoked.copy()  # make a copy, since we'll be picking
    ch_type = _get_plot_ch_type(evoked, ch_type)
    # time units / formatting
    time_unit, _ = _check_time_unit(time_unit, evoked.times)
    scaling_time = 1.0 if time_unit == "s" else 1e3
    _validate_type(time_format, (None, str), "time_format")
    if time_format is None:
        time_format = "%0.3f s" if time_unit == "s" else "%01d ms"
    del time_unit
    # mask_params defaults
    mask_params = _handle_default("mask_params", mask_params)
    mask_params["markersize"] *= size / 2.0
    mask_params["markeredgewidth"] *= size / 2.0
    # setup various parameters, and prepare outlines
    (
        picks,
        pos,
        merge_channels,
        names,
        ch_type,
        sphere,
        clip_origin,
    ) = _prepare_topomap_plot(evoked, ch_type, sphere=sphere)
    outlines = _make_head_outlines(sphere, pos, outlines, clip_origin)
    # check interactive
    axes_given = axes is not None
    interactive = isinstance(times, str) and times == "interactive"
    if interactive and axes_given:
        raise ValueError("User-provided axes not allowed when times='interactive'.")
    # units, scalings
    key = "grad" if ch_type.startswith("planar") else ch_type
    default_scaling = _handle_default("scalings", None)[key]
    scaling = _handle_default("scalings", scalings)[key]
    # if non-default scaling, fall back to "AU" if unit wasn't given by user
    key = "misc" if scaling != default_scaling else key
    unit = _handle_default("units", units)[key]
    # ch_names (required for NIRS)
    ch_names = names
    names = _prepare_sensor_names(names, show_names)
    # apply projections before picking. NOTE: the `if proj is True`
    # anti-pattern is needed here to exclude proj='interactive'
    _check_option("proj", proj, (True, False, "interactive", "reconstruct"))
    if proj is True and not evoked.proj:
        evoked.apply_proj()
    elif proj == "reconstruct":
        evoked._reconstruct_proj()
    data = evoked.data

    # remove compensation matrices (safe: only plotting & already made copy)
    with evoked.info._unlock():
        evoked.info["comps"] = []
    evoked = evoked._pick_drop_channels(picks, verbose=False)
    # determine which times to plot
    if isinstance(axes, plt.Axes):
        axes = [axes]
    n_peaks = len(axes) - int(colorbar) if axes_given else None
    times = _process_times(evoked, times, n_peaks)
    n_times = len(times)
    space = 1 / (2.0 * evoked.info["sfreq"])
    if max(times) > max(evoked.times) + space or min(times) < min(evoked.times) - space:
        raise ValueError(
            f"Times should be between {evoked.times[0]:0.3} and {evoked.times[-1]:0.3}."
        )
    # create axes
    want_axes = n_times + int(colorbar)
    if interactive:
        height_ratios = [5, 1]
        nrows = 2
        ncols = n_times
        width = size * want_axes
        height = size + max(0, 0.1 * (4 - size))
        fig = figure_nobar(figsize=(width * 1.5, height * 1.5))
        gs = GridSpec(nrows, ncols, height_ratios=height_ratios, figure=fig)
        axes = []
        for ax_idx in range(n_times):
            axes.append(plt.subplot(gs[0, ax_idx]))
    elif axes is None:
        fig, axes, ncols, nrows = _prepare_trellis(
            n_times, ncols=ncols, nrows=nrows, size=size
        )
    else:
        nrows, ncols = None, None  # Deactivate ncols when axes were passed
        fig = axes[0].get_figure()
        # check: enough space for colorbar?
        if len(axes) != want_axes:
            cbar_err = " plus one for the colorbar" if colorbar else ""
            raise RuntimeError(
                f"You must provide {want_axes} axes (one for "
                f"each time{cbar_err}), got {len(axes)}."
            )
    del want_axes
    # find first index that's >= (to rounding error) to each time point
    time_idx = [
        np.where(
            _time_mask(evoked.times, tmin=t, tmax=None, sfreq=evoked.info["sfreq"])
        )[0][0]
        for t in times
    ]
    # do averaging if requested
    avg_err = (
        '"average" must be `None`, a positive number of seconds, or '
        "an array-like object of the previous"
    )

    averaged_times = []
    if average is None:
        average = np.array([None] * n_times)
        data = data[np.ix_(picks, time_idx)]
    else:
        if _is_numeric(average):
            average = np.array([average] * n_times)
        elif np.array(average).ndim == 0:
            # It should be an array-like object
            raise TypeError(f"{avg_err}; got type: {type(average)}.")
        else:
            average = np.array(average)

        if len(average) != n_times:
            raise ValueError(
                f"You requested to plot topographic maps for {n_times} time "
                f"points, but provided {len(average)} periods for "
                f"averaging. The number of time points and averaging periods "
                f"must be equal."
            )
        data_ = np.zeros((len(picks), len(time_idx)))

        for average_idx, (this_average, this_time, this_time_idx) in enumerate(
            zip(average, evoked.times[time_idx], time_idx)
        ):
            if (_is_numeric(this_average) and this_average <= 0) or (
                not _is_numeric(this_average) and this_average is not None
            ):
                if len(average) == 1:
                    msg = f"{avg_err}; got {this_average}"
                else:
                    msg = f"{avg_err}; got {this_average} in {average}"
                raise ValueError(msg)

            if this_average is None:
                data_[:, average_idx] = data[picks][:, this_time_idx]
                averaged_times.append([this_time])
            else:
                tmin_ = this_time - this_average / 2
                tmax_ = this_time + this_average / 2
                time_mask = (tmin_ < evoked.times) & (evoked.times < tmax_)
                data_[:, average_idx] = data[picks][:, time_mask].mean(-1)
                averaged_times.append(evoked.times[time_mask])
        data = data_

    # apply scalings and merge channels
    data *= scaling
    # Callable mapping one time sample of the picked (unmerged) data to merged
    # data. The interactive slider uses it so that it shows exactly the same
    # channels as the static topomaps; None means no merging happened.
    merge_func = None
    if merge_channels:
        # check modality
        if any(ch["coil_type"] in _opm_coils for ch in evoked.info["chs"]):
            modality = "opm"
        elif ch_type in _fnirs_types:
            modality = "fnirs"
        else:
            modality = "other"
        unmerged_ch_names = ch_names

        def merge_func(x):
            # _merge_ch_data expects (n_channels, n_samples) and returns a
            # (data, names) tuple; feed a single column and keep only the data
            return _merge_ch_data(
                x.reshape(len(x), -1), ch_type, unmerged_ch_names, modality=modality
            )[0]

        # merge data
        data, ch_names = _merge_ch_data(data, ch_type, ch_names, modality=modality)
        # if ch_type in _fnirs_types:
        if modality != "other":
            merge_channels = False
    # apply mask if requested
    if mask is not None:
        mask = mask.astype(bool, copy=False)
        if ch_type == "grad":
            mask_ = (
                mask[np.ix_(picks[::2], time_idx)] | mask[np.ix_(picks[1::2], time_idx)]
            )
        else:  # mag, eeg, planar1, planar2
            mask_ = mask[np.ix_(picks, time_idx)]
    # set up colormap
    _vlim = [
        _setup_vmin_vmax(data[:, i], *vlim, norm=merge_channels) for i in range(n_times)
    ]
    _vlim = [np.min(_vlim), np.max(_vlim)]
    cmap = _setup_cmap(cmap, n_axes=n_times, norm=_vlim[0] >= 0)
    # set up contours
    if not isinstance(contours, list | np.ndarray):
        _, contours = _set_contour_locator(*_vlim, contours)
    else:
        if vlim[0] is None and np.any(contours < _vlim[0]):
            _vlim[0] = contours[0]
        if vlim[1] is None and np.any(contours > _vlim[1]):
            _vlim[1] = contours[-1]

    # prepare for main loop over times
    kwargs = dict(
        sensors=sensors,
        res=res,
        names=names,
        cmap=cmap[0],
        cnorm=cnorm,
        mask_params=mask_params,
        outlines=outlines,
        contours=contours,
        image_interp=image_interp,
        show=False,
        extrapolate=extrapolate,
        sphere=sphere,
        border=border,
        ch_type=ch_type,
    )
    images, contours_ = [], []
    # loop over times
    for average_idx, (time, this_average) in enumerate(zip(times, average)):
        tp, cn, interp = _plot_topomap(
            data[:, average_idx],
            pos,
            axes=axes[average_idx],
            mask=mask_[:, average_idx] if mask is not None else None,
            vmin=_vlim[0],
            vmax=_vlim[1],
            **kwargs,
        )

        images.append(tp)
        if cn is not None:
            contours_.append(cn)
        if time_format != "":
            if this_average is None:
                axes_title = time_format % (time * scaling_time)
            else:
                tmin_ = averaged_times[average_idx][0]
                tmax_ = averaged_times[average_idx][-1]
                from_time = time_format % (tmin_ * scaling_time)
                from_time = from_time.split(" ")[0]  # Remove unit
                to_time = time_format % (tmax_ * scaling_time)
                axes_title = f"{from_time} – {to_time}"
                del from_time, to_time, tmin_, tmax_
            axes[average_idx].set_title(axes_title)

    if interactive:
        # Add a slider to the figure and start publishing and subscribing to time_change
        # events.
        kwargs.update(vlim=_vlim)
        axes.append(fig.add_subplot(gs[1]))
        slider = Slider(
            axes[-1],
            "Time",
            evoked.times[0],
            evoked.times[-1],
            valinit=times[0],
            valfmt="%1.2fs",
        )
        slider.vline.remove()  # remove initial point indicator
        # merge exactly like the static plot did (grads, OPM and fNIRS alike)
        func = merge_func if merge_func is not None else (lambda x: x)

        def _slider_changed(val):
            publish(fig, TimeChange(time=val))

        slider.on_changed(_slider_changed)
        ts = np.tile(evoked.times, len(evoked.data)).reshape(evoked.data.shape)
        axes[-1].plot(ts, evoked.data, color="k")
        axes[-1].slider = slider

        subscribe(
            fig,
            "time_change",
            partial(
                _on_time_change,
                fig=fig,
                data=evoked.data,
                times=evoked.times,
                pos=pos,
                scaling=scaling,
                func=func,
                time_format=time_format,
                scaling_time=scaling_time,
                slider=slider,
                kwargs=kwargs,
            ),
        )
        subscribe(
            fig,
            "colormap_range",
            partial(_on_colormap_range, kwargs=kwargs),
        )

    if colorbar:
        if nrows is None or ncols is None:
            # axes were given by the user, so don't resize the colorbar
            cax = axes[-1]
        else:  # use the default behavior
            cax = None

        cbar = fig.colorbar(images[-1], ax=axes, cax=cax, format=cbar_fmt, shrink=0.6)
        if unit is not None:
            cbar.ax.set_title(unit)
        if cn is not None:
            cbar.set_ticks(contours)
        cbar.ax.tick_params(labelsize=7)
        if cmap[1]:
            for im in images:
                im.axes.CB = DraggableColorbar(
                    cbar, im, kind="evoked_topomap", ch_type=ch_type
                )

    if proj == "interactive":
        _check_delayed_ssp(evoked)
        params = dict(
            evoked=evoked,
            fig=fig,
            projs=evoked.info["projs"],
            picks=picks,
            images=images,
            contours_=contours_,
            pos=pos,
            time_idx=time_idx,
            res=res,
            plot_update_proj_callback=_plot_update_evoked_topomap,
            merge_channels=merge_channels,
            scale=scaling,
            axes=axes[: len(axes) - bool(interactive)],
            contours=contours,
            interp=interp,
            extrapolate=extrapolate,
        )
        _draw_proj_checkbox(None, params)
        # This is mostly for testing purposes, but it's also consistent with
        # raw.plot, so maybe not a bad thing in principle either
        from mne.viz._figure import BrowserParams

        fig.mne = BrowserParams(proj_checkboxes=params["proj_checks"])

    plt_show(show, block=False)
    if axes_given:
        fig.canvas.draw()
    return fig


def _resize_cbar(cax, n_fig_axes, size=1):
    """Resize colorbar."""
    cpos = cax.get_position()
    if size <= 1:
        cpos.x0 = 1 - (0.7 + 0.1 / size) / n_fig_axes
    cpos.x1 = cpos.x0 + 0.1 / n_fig_axes
    cpos.y0 = 0.2
    cpos.y1 = 0.7
    cax.set_position(cpos)


def _on_time_change(
    event,
    fig,
    data,
    times,
    pos,
    scaling,
    func,
    time_format,
    scaling_time,
    slider,
    kwargs,
):
    """Handle updating topomap to show a new time.

    Parameters
    ----------
    event : TimeChange
        The UI event; ``event.time`` is the requested time.
    fig : matplotlib.figure.Figure
        The interactive figure; its first axes holds the topomap.
    data : ndarray, shape (n_channels, n_times)
        The (unscaled) sensor data.
    times : ndarray, shape (n_times,)
        Time points of ``data``.
    pos : array
        Sensor positions passed to ``plot_topomap``.
    scaling : float
        Factor applied to the selected sample before plotting.
    func : callable
        Applied to the selected sample before plotting (e.g. channel merging).
    time_format : str | None
        Format for the axes title. An empty string (or None) means no title,
        matching the behavior of the static topomaps.
    scaling_time : float
        Factor converting ``event.time`` to the displayed time unit.
    slider : matplotlib.widgets.Slider
        The time slider, kept in sync with ``event.time``.
    kwargs : dict
        Extra keyword arguments for ``plot_topomap``.
    """
    # snap to the sample closest to the requested time
    idx = np.argmin(np.abs(times - event.time))
    data = func(data[:, idx]).ravel() * scaling
    ax = fig.axes[0]
    ax.clear()
    im, _ = plot_topomap(data, pos, axes=ax, **kwargs)
    if hasattr(ax, "CB"):
        ax.CB.mappable = im
        _resize_cbar(ax.CB.cbar.ax, 2)
    # An empty format string omits the label; formatting it would raise
    # TypeError ("not all arguments converted during string formatting").
    if time_format:
        ax.set_title(time_format % (event.time * scaling_time))
    # Updating the slider will generate a new time_change event. To prevent an
    # infinite loop, only update the slider if the time has actually changed.
    if event.time != slider.val:
        slider.set_val(event.time)
    ax.figure.canvas.draw_idle()


def _on_colormap_range(event, kwargs):
    """Handle updating colormap range."""
    logger.debug(f"Updating colormap range to {event.fmin}, {event.fmax}")
    kwargs.update(vlim=(event.fmin, event.fmax), cmap=event.cmap)


def _plot_topomap_multi_cbar(
    data,
    pos,
    ax,
    *,
    vlim,
    title,
    unit,
    cmap,
    outlines,
    colorbar,
    cbar_fmt,
    sphere,
    ch_type,
    sensors,
    names,
    mask,
    mask_params,
    contours,
    image_interp,
    extrapolate,
    border,
    res,
    size,
    cnorm,
):
    """Draw one topomap into ``ax``, optionally with its own colorbar.

    Missing entries of ``vlim`` fall back to the minimum / maximum of
    ``data``. When ``colorbar`` is true, the colorbar is ticked at exactly the
    two limits and labeled with ``unit`` (if given).
    """
    _hide_frame(ax)
    lower = vlim[0]
    upper = vlim[1]
    if lower is None:
        lower = np.min(data)
    if upper is None:
        upper = np.max(data)
    _vlim = (lower, upper)
    # Treat the range as "norm" (non-diverging colormap) whenever both limits
    # share a sign or one of them is zero -- e.g. all-negative power in dB.
    signs = np.sign(_vlim)
    same_sign = len(set(signs)) == 1
    norm = same_sign or np.any(signs == 0)

    _cmap = _setup_cmap(cmap, norm=norm)
    if title is not None:
        ax.set_title(title, fontsize=10)
    topo_kwargs = dict(
        ch_type=ch_type,
        sensors=sensors,
        names=names,
        mask=mask,
        mask_params=mask_params,
        contours=contours,
        outlines=outlines,
        sphere=sphere,
        image_interp=image_interp,
        extrapolate=extrapolate,
        border=border,
        res=res,
        size=size,
        cmap=_cmap[0],
        vlim=_vlim,
        cnorm=cnorm,
        axes=ax,
        show=False,
        onselect=None,
    )
    im, _ = plot_topomap(data, pos, **topo_kwargs)

    if not colorbar:
        return
    cbar, _cax = _add_colorbar(ax, im, cmap, title=None, format_=cbar_fmt)
    cbar.set_ticks(_vlim)
    if unit is not None:
        cbar.ax.set_ylabel(unit, fontsize=8)
    cbar.ax.tick_params(labelsize=8)


@legacy(alt="Epochs.compute_psd().plot_topomap()")
@verbose
def plot_epochs_psd_topomap(
    epochs,
    bands=None,
    tmin=None,
    tmax=None,
    proj=False,
    *,
    bandwidth=None,
    adaptive=False,
    low_bias=True,
    normalization="length",
    ch_type=None,
    normalize=False,
    agg_fun=None,
    dB=False,
    sensors=True,
    names=None,
    mask=None,
    mask_params=None,
    contours=0,
    outlines="head",
    sphere=None,
    image_interp=_INTERPOLATION_DEFAULT,
    extrapolate=_EXTRAPOLATE_DEFAULT,
    border=_BORDER_DEFAULT,
    res=64,
    size=1,
    cmap=None,
    vlim=(None, None),
    cnorm=None,
    colorbar=True,
    cbar_fmt="auto",
    units=None,
    axes=None,
    show=True,
    n_jobs=None,
    verbose=None,
):
    """Plot the topomap of the power spectral density across epochs.

    Parameters
    ----------
    epochs : instance of Epochs
        The epochs object.
    %(bands_psd_topo)s
    %(tmin_tmax_psd)s
    %(proj_psd)s
    bandwidth : float
        The bandwidth of the multi taper windowing function in Hz. The default
        value is a window half-bandwidth of 4 Hz.
    adaptive : bool
        Use adaptive weights to combine the tapered spectra into PSD
        (slow, use n_jobs >> 1 to speed up computation).
    low_bias : bool
        Only use tapers with more than 90%% spectral concentration within
        bandwidth.
    %(normalization)s
    %(ch_type_topomap_psd)s
    %(normalize_psd_topo)s
    %(agg_fun_psd_topo)s
    %(dB_plot_topomap)s
    %(sensors_topomap)s
    %(names_topomap)s
    %(mask_evoked_topomap)s
    %(mask_params_topomap)s
    %(contours_topomap)s
    %(outlines_topomap)s
    %(sphere_topomap_auto)s
    %(image_interp_topomap)s
    %(extrapolate_topomap)s

        .. versionchanged:: 0.21

           - The default was changed to ``'local'`` for MEG sensors.
           - ``'local'`` was changed to use a convex hull mask
           - ``'head'`` was changed to extrapolate out to the clipping circle.
    %(border_topomap)s

        .. versionadded:: 0.20
    %(res_topomap)s
    %(size_topomap)s
    %(cmap_topomap)s
    %(vlim_plot_topomap_psd)s

        .. versionadded:: 0.21
    %(cnorm)s

        .. versionadded:: 1.2
    %(colorbar_topomap)s
    %(cbar_fmt_topomap_psd)s
    %(units_topomap)s
    %(axes_spectrum_plot_topomap)s
    %(show)s
    %(n_jobs)s
    %(verbose)s

    Returns
    -------
    fig : instance of Figure
        Figure showing one scalp topography per frequency band.
    """
    # Legacy wrapper: compute a Spectrum from the epochs, then delegate the
    # plotting to Spectrum.plot_topomap.
    from ..channels import rename_channels
    from ..time_frequency import Spectrum

    # Split this function's arguments into PSD-computation kwargs and plotting
    # kwargs. NOTE(review): no argument values are passed explicitly, so
    # _split_psd_kwargs presumably inspects this function's local variables;
    # adding or renaming locals above this call may change its output — confirm.
    init_kw, plot_kw = _split_psd_kwargs(plot_fun=Spectrum.plot_topomap)
    spectrum = epochs.compute_psd(**init_kw)
    # Sensor names are hidden unless custom ``names`` are supplied below.
    plot_kw.setdefault("show_names", False)
    if names is not None:
        # Rename channels of the freshly computed spectrum (not of ``epochs``),
        # pairing existing channel names with ``names`` in order (zip stops at
        # the shorter of the two), then display those names on the plot.
        rename_channels(
            spectrum.info, dict(zip(spectrum.ch_names, names)), verbose=verbose
        )
        plot_kw["show_names"] = True
    return spectrum.plot_topomap(**plot_kw)


@fill_doc
def plot_psds_topomap(
    psds,
    freqs,
    pos,
    *,
    bands=None,
    ch_type="eeg",
    normalize=False,
    agg_fun=None,
    dB=True,
    sensors=True,
    names=None,
    mask=None,
    mask_params=None,
    contours=0,
    outlines="head",
    sphere=None,
    image_interp=_INTERPOLATION_DEFAULT,
    extrapolate=_EXTRAPOLATE_DEFAULT,
    border=_BORDER_DEFAULT,
    res=64,
    size=1,
    cmap=None,
    vlim=(None, None),
    cnorm=None,
    colorbar=True,
    cbar_fmt="auto",
    unit=None,
    axes=None,
    show=True,
):
    """Plot spatial maps of PSDs.

    Parameters
    ----------
    psds : array of float, shape (n_channels, n_freqs)
        Power spectral densities. The array is not modified.
    freqs : array of float, shape (n_freqs,)
        Frequencies used to compute psds.
    %(pos_topomap_psd)s
    %(bands_psd_topo)s
    %(ch_type_topomap)s
    %(normalize_psd_topo)s
    %(agg_fun_psd_topo)s
    %(dB_plot_topomap)s
    %(sensors_topomap)s
    %(names_topomap)s
    %(mask_evoked_topomap)s
    %(mask_params_topomap)s
    %(contours_topomap)s
    %(outlines_topomap)s
    %(sphere_topomap_auto)s
    %(image_interp_topomap)s
    %(extrapolate_topomap)s

        .. versionchanged:: 0.21

           - The default was changed to ``'local'`` for MEG sensors.
           - ``'local'`` was changed to use a convex hull mask
           - ``'head'`` was changed to extrapolate out to the clipping circle.
    %(border_topomap)s

        .. versionadded:: 0.20
    %(res_topomap)s
    %(size_topomap)s
    %(cmap_topomap)s
    %(vlim_plot_topomap_psd)s

        .. versionadded:: 0.21
    %(cnorm)s

        .. versionadded:: 1.2
    %(colorbar_topomap)s
    %(cbar_fmt_topomap_psd)s
    unit : str | None
        Measurement unit to be displayed with the colorbar. If ``None``, no
        unit is displayed (only "power" or "dB" as appropriate).
    %(axes_spectrum_plot_topomap)s
    %(show)s

    Returns
    -------
    fig : instance of matplotlib.figure.Figure
        Figure with a topomap subplot for each band.
    """
    import matplotlib.pyplot as plt
    from matplotlib.axes import Axes

    # handle some defaults
    sphere = _check_sphere(sphere)
    if cbar_fmt == "auto":
        cbar_fmt = "%0.1f" if dB else "%0.3f"
    # make sure `bands` is a dict
    if bands is None:
        bands = {
            "Delta (0-4 Hz)": (0, 4),
            "Theta (4-8 Hz)": (4, 8),
            "Alpha (8-12 Hz)": (8, 12),
            "Beta (12-30 Hz)": (12, 30),
            "Gamma (30-45 Hz)": (30, 45),
        }
    elif not hasattr(bands, "keys"):
        # convert legacy list-of-tuple input to a dict
        bands = {band[-1]: band[:-1] for band in bands}
        logger.info(
            "converting legacy list-of-tuples input to a dict for the `bands` parameter"
        )
    else:
        # work on a copy: single frequencies are expanded below, and that must
        # not leak back into the caller's dict
        bands = dict(bands)
    # upconvert single freqs to band upper/lower edges as needed (a band of one
    # bin width centered on the frequency bin closest to the requested value)
    bin_spacing = np.diff(freqs)[0]
    bin_edges = np.array([0, bin_spacing]) - bin_spacing / 2
    for band, _edges in bands.items():
        if not hasattr(_edges, "__len__"):
            _edges = (_edges,)
        if len(_edges) == 1:
            bands[band] = tuple(bin_edges + freqs[np.argmin(np.abs(freqs - _edges[0]))])
    # normalize data (if requested); out-of-place so the caller's array is not
    # modified (and integer input does not fail on in-place true division)
    if normalize:
        psds = psds / psds.sum(axis=-1, keepdims=True)
        assert np.allclose(psds.sum(axis=-1), 1.0)
    # aggregate within bands
    if agg_fun is None:
        agg_fun = np.sum if normalize else np.mean
    freq_masks = list()
    for band, (fmin, fmax) in bands.items():
        _mask = (fmin < freqs) & (freqs < fmax)
        # make sure no bands are empty
        if _mask.sum() == 0:
            raise RuntimeError(f'No frequencies in band "{band}" ({fmin}, {fmax})')
        freq_masks.append(_mask)
    band_data = [agg_fun(psds[:, _mask], axis=1) for _mask in freq_masks]
    if dB and not normalize:
        band_data = [10 * np.log10(_d) for _d in band_data]
    # handle vmin/vmax
    joint_vlim = vlim == "joint"
    if joint_vlim:
        vlim = (np.array(band_data).min(), np.array(band_data).max())
    # unit label
    if unit is None:
        unit = "dB" if dB and not normalize else "power"
    else:
        _dB = dB and not normalize
        unit = _format_units_psd(unit, dB=_dB)
    # set up figure / axes
    n_axes = len(bands)
    user_passed_axes = axes is not None
    if user_passed_axes:
        if isinstance(axes, Axes):
            axes = [axes]
        _validate_if_list_of_axes(axes, n_axes)
        fig = axes[0].figure
    else:
        fig, axes = plt.subplots(
            1, n_axes, figsize=(2 * n_axes, 1.5), layout="constrained"
        )
        if n_axes == 1:
            axes = [axes]
    # loop over subplots/frequency bands
    for ax, _mask, _data, (title, (fmin, fmax)) in zip(
        axes, freq_masks, band_data, bands.items()
    ):
        # with a joint color scale, a single colorbar on the last axes suffices
        plot_colorbar = False if not colorbar else (not joint_vlim) or ax is axes[-1]
        _plot_topomap_multi_cbar(
            _data,
            pos,
            ax,
            title=title,
            vlim=vlim,
            cmap=cmap,
            outlines=outlines,
            colorbar=plot_colorbar,
            unit=unit,
            cbar_fmt=cbar_fmt,
            sphere=sphere,
            ch_type=ch_type,
            sensors=sensors,
            names=names,
            mask=mask,
            mask_params=mask_params,
            contours=contours,
            image_interp=image_interp,
            extrapolate=extrapolate,
            border=border,
            res=res,
            size=size,
            cnorm=cnorm,
        )

    if not user_passed_axes:
        fig.canvas.draw()
        plt_show(show)
    return fig


@fill_doc
def plot_layout(layout, picks=None, show_axes=False, show=True):
    """Plot the sensor positions.

    Parameters
    ----------
    layout : Layout
        Layout instance specifying sensor positions.
    %(picks_layout)s
    show_axes : bool
        Show layout axes if True. Defaults to False.
    show : bool
        Show figure if True. Defaults to True.

    Returns
    -------
    fig : instance of Figure
        Figure containing the sensor topography.

    Notes
    -----
    .. versionadded:: 0.12.0
    """
    import matplotlib.pyplot as plt

    # Square figure whose side is the larger of the default figure dimensions.
    fig = plt.figure(
        figsize=(max(plt.rcParams["figure.figsize"]),) * 2, layout="constrained"
    )
    ax = fig.add_subplot(111)
    ax.set(xticks=[], yticks=[], aspect="equal")
    # Outline the unit square (layout positions are presumably normalized to
    # [0, 1]).
    outlines = dict(border=([0, 1, 1, 0, 0], [0, 0, 1, 1, 0]))
    _draw_outlines(ax, outlines)
    # Work on a copy so picking does not alter the caller's layout.
    layout = layout.copy().pick(picks)
    for p, ch_id in zip(layout.pos, layout.names):
        # ``p`` is (x, y, width, height) of the channel box; label its center.
        center_pos = np.array((p[0] + p[2] / 2.0, p[1] + p[3] / 2.0))
        ax.annotate(
            ch_id,
            xy=center_pos,
            horizontalalignment="center",
            verticalalignment="center",
            size="x-small",
        )
        if show_axes:
            x1, x2, y1, y2 = p[0], p[0] + p[2], p[1], p[1] + p[3]
            ax.plot([x1, x1, x2, x2, x1], [y1, y2, y2, y1, y1], color="k")
    ax.axis("off")
    plt_show(show)
    return fig


def _onselect(
    eclick,
    erelease,
    tfr,
    pos,
    ch_type,
    itmin,
    itmax,
    ifmin,
    ifmax,
    cmap,
    fig,
    layout=None,
):
    """Handle drawing average tfr over channels called from topomap.

    Rectangle-selection callback for a TFR topomap. Sensors whose ``pos``
    falls strictly inside the selected rectangle are colored red (others
    black). Their TFR, restricted to ``ifmin:ifmax`` and ``itmin:itmax``, is
    averaged and shown as an image in a separate figure.

    Parameters
    ----------
    eclick, erelease : matplotlib mouse events
        Press and release events delimiting the selection rectangle.
    tfr : AverageTFR-like
        Object providing ``data``, ``info``, ``ch_names``, ``times`` and
        ``freqs``.
    pos : ndarray
        Topomap coordinates of the plotted sensors (at least two columns).
        For ``'grad'`` each row stands for one gradiometer pair.
    ch_type : 'mag' | 'grad' | 'eeg'
        Channel type shown in the topomap.
    itmin, itmax, ifmin, ifmax : int | None
        Time / frequency index bounds; ``None`` means unbounded.
    cmap : colormap
        Colormap for the averaged TFR image.
    fig : list
        Holds zero or one figure. It is reused across calls so that repeated
        selections update the same window.
    layout : Layout | None
        Passed to ``_pair_grad_sensors`` for ``'grad'``.

    Raises
    ------
    ValueError
        If ``ch_type`` is not one of the supported types.
    """
    import matplotlib.pyplot as plt
    from matplotlib.collections import PathCollection

    from ..channels.layout import _pair_grad_sensors

    ax = eclick.inaxes
    xmin = min(eclick.xdata, erelease.xdata)
    xmax = max(eclick.xdata, erelease.xdata)
    ymin = min(eclick.ydata, erelease.ydata)
    ymax = max(eclick.ydata, erelease.ydata)
    indices = (
        (pos[:, 0] < xmax)
        & (pos[:, 0] > xmin)
        & (pos[:, 1] < ymax)
        & (pos[:, 1] > ymin)
    )
    colors = ["r" if ii else "k" for ii in indices]
    indices = np.where(indices)[0]
    for collection in ax.collections:
        if isinstance(collection, PathCollection):  # this is our "scatter"
            collection.set_color(colors)
    ax.figure.canvas.draw()
    if len(indices) == 0:
        return
    # ``indices`` are rows of ``pos``; translate them into rows of
    # ``tfr.data`` so that the averaged data and the channel names agree
    # even when ``tfr`` contains several channel types.
    if ch_type == "mag":
        picks = pick_types(tfr.info, meg=ch_type, ref_meg=False)
        idxs = picks[indices]
    elif ch_type == "grad":
        grads = _pair_grad_sensors(tfr.info, layout=layout, topomap_coords=False)
        idxs = list()
        for idx in indices:
            idxs.append(grads[idx * 2])
            idxs.append(grads[idx * 2 + 1])  # pair of grads
    elif ch_type == "eeg":
        picks = pick_types(tfr.info, meg=False, eeg=True, ref_meg=False)
        idxs = picks[indices]
    else:
        raise ValueError(
            "Averaging over selected channels is only supported for 'mag', "
            f"'grad' and 'eeg' channels, got {ch_type!r}."
        )
    data = np.mean(tfr.data[idxs, ifmin:ifmax, itmin:itmax], axis=0)
    chs = [tfr.ch_names[x] for x in idxs]
    logger.info("Averaging TFR over channels " + str(chs))
    if len(fig) == 0:
        fig.append(figure_nobar())
    if not plt.fignum_exists(fig[0].number):
        fig[0] = figure_nobar()
    tfr_fig = fig[0]
    # Reuse the image axes on repeated selections. Otherwise every selection
    # would stack a new axes on top of the previous one.
    existing_axes = tfr_fig.get_axes()
    if existing_axes:
        ax = existing_axes[0]
        ax.clear()
    else:
        ax = tfr_fig.add_subplot(111)
    itmax = len(tfr.times) - 1 if itmax is None else min(itmax, len(tfr.times) - 1)
    ifmax = len(tfr.freqs) - 1 if ifmax is None else min(ifmax, len(tfr.freqs) - 1)
    if itmin is None:
        itmin = 0
    if ifmin is None:
        ifmin = 0
    extent = (
        tfr.times[itmin] * 1e3,
        tfr.times[itmax] * 1e3,
        tfr.freqs[ifmin],
        tfr.freqs[ifmax],
    )

    title = f"Average over {len(chs)} {ch_type} channels."
    ax.set_title(title)
    ax.set_xlabel("Time (ms)")
    ax.set_ylabel("Frequency (Hz)")
    img = ax.imshow(data, extent=extent, aspect="auto", origin="lower", cmap=cmap)
    if len(tfr_fig.get_axes()) < 2:
        # First draw: create the colorbar (this adds the second axes) and
        # store it on its axes so that later selections can update it.
        cbar = tfr_fig.colorbar(mappable=img)
        tfr_fig.get_axes()[1].cbar = cbar
    else:
        # ``Colorbar.on_mappable_changed`` was removed from matplotlib;
        # ``update_normal`` re-targets the colorbar to the new image.
        tfr_fig.get_axes()[1].cbar.update_normal(img)
    tfr_fig.canvas.draw()
    plt.figure(tfr_fig.number)
    plt_show(True)


def _prepare_topomap(pos, ax, check_nonzero=True):
    """Hide the frame of ``ax`` and make sure sensor positions are usable.

    Parameters
    ----------
    pos : ndarray
        Sensor positions.
    ax : Axes
        Axes whose ticks and frame are removed.
    check_nonzero : bool
        If True, reject ``pos`` when every entry is zero.

    Raises
    ------
    RuntimeError
        If ``check_nonzero`` is True and ``pos`` contains no nonzero value.
    """
    _hide_frame(ax)
    if not check_nonzero:
        return
    if pos.any():
        return
    raise RuntimeError(
        "No position information found, cannot compute geometries for topomap."
    )


def _hide_frame(ax):
    """Remove the tick marks on both axes and the surrounding frame of ``ax``."""
    # The returned ticks are unused. The call is kept as-is because querying
    # ticks may update the axes' tick/limit state before the ticks are cleared.
    ax.get_yticks()
    for axis in (ax.xaxis, ax.yaxis):
        axis.set_ticks([])
    ax.set_frame_on(False)


def _check_extrapolate(extrapolate, ch_type):
    """Validate ``extrapolate`` and resolve ``'auto'`` for ``ch_type``.

    ``'auto'`` becomes ``'local'`` when ``ch_type`` is in
    ``_MEG_CH_TYPES_SPLIT`` and ``'head'`` otherwise. Any other allowed value
    (``'box'``, ``'local'``, ``'head'``) is returned unchanged.
    """
    _check_option("extrapolate", extrapolate, ("box", "local", "head", "auto"))
    if extrapolate != "auto":
        return extrapolate
    if ch_type in _MEG_CH_TYPES_SPLIT:
        return "local"
    return "head"


@verbose
def _init_anim(
    ax,
    ax_line,
    ax_cbar,
    params,
    merge_channels,
    sphere,
    ch_type,
    image_interp,
    extrapolate,
    verbose,
):
    """Initialize animated topomap.

    ``init_func`` for ``FuncAnimation``. It draws the first frame (image,
    contours, head outlines, colorbar and, optionally, a butterfly plot). It
    also precomputes the interpolated image of every frame, so ``_animate``
    only has to redraw.

    Parameters
    ----------
    ax : Axes
        Axes for the topomap.
    ax_line : Axes | None
        Axes for the butterfly plot; only used if ``params["butterfly"]``.
    ax_cbar : Axes
        Axes that receives the colorbar.
    params : dict
        Shared animation state. Read: ``data``, ``butterfly``, ``all_times``,
        ``pos``, ``clip_origin``, ``frames`` and, if present, ``vmin`` /
        ``vmax``. Written: ``line`` (butterfly only), ``Zis``, ``vmin``,
        ``vmax``, ``Xi``, ``Yi``, ``Zi``, ``extent``, ``cmap``, ``cont_lims``,
        ``text``, ``patch`` and ``outlines``.
    merge_channels : bool
        If True, the data is merged with ``_merge_ch_data(..., "grad", ...)``
        before interpolation (the butterfly plot still shows unmerged data).
    sphere : object
        Passed to ``_make_head_outlines``.
    ch_type : str
        Not used in this function's body.
    image_interp, extrapolate : str
        Passed to ``_setup_interp`` (``extrapolate`` also to ``_get_patch``).
    verbose : object
        Consumed by the ``@verbose`` decorator.

    Returns
    -------
    artists : tuple
        Artists the animation updates (used for blitting): the butterfly time
        marker (if any), the image, the time text, then the contour
        collections.
    """
    logger.info("Initializing animation...")
    data = params["data"]
    items = list()
    vmin = params["vmin"] if "vmin" in params else None
    vmax = params["vmax"] if "vmax" in params else None
    if params["butterfly"]:
        # One black trace per channel plus a red marker at the current time.
        all_times = params["all_times"]
        for idx in range(len(data)):
            ax_line.plot(all_times, data[idx], color="k", lw=1)
        # The limits resolved here are fed into the second
        # ``_setup_vmin_vmax`` call below.
        vmin, vmax = _setup_vmin_vmax(data, vmin, vmax)
        ax_line.set(
            yticks=np.around(np.linspace(vmin, vmax, 5), -1), xlim=all_times[[0, -1]]
        )
        params["line"] = ax_line.axvline(all_times[0], color="r")
        items.append(params["line"])
    if merge_channels:
        from mne.channels.layout import _merge_ch_data

        data, _ = _merge_ch_data(data, "grad", [])
    # Strictly positive data -> sequential colormap, otherwise diverging.
    norm = True if np.min(data) > 0 else False
    cmap = "Reds" if norm else "RdBu_r"

    vmin, vmax = _setup_vmin_vmax(data, vmin, vmax, norm)

    outlines = _make_head_outlines(sphere, params["pos"], "head", params["clip_origin"])

    _hide_frame(ax)
    extent, Xi, Yi, interp = _setup_interp(
        pos=params["pos"],
        res=64,
        image_interp=image_interp,
        extrapolate=extrapolate,
        outlines=outlines,
        border=0,
    )

    patch_ = _get_patch(outlines, extrapolate, interp, ax)

    # Interpolate every requested frame up front; _animate only indexes
    # params["Zis"].
    params["Zis"] = list()
    for frame in params["frames"]:
        params["Zis"].append(interp.set_values(data[:, frame])(Xi, Yi))
    Zi = params["Zis"][0]
    zi_min = np.nanmin(params["Zis"])
    zi_max = np.nanmax(params["Zis"])
    # Six contour levels evenly spaced strictly between the global min and
    # max over all frames, so the levels stay fixed during the animation.
    cont_lims = np.linspace(zi_min, zi_max, 7, endpoint=False)[1:]
    params.update(
        {
            "vmin": vmin,
            "vmax": vmax,
            "Xi": Xi,
            "Yi": Yi,
            "Zi": Zi,
            "extent": extent,
            "cmap": cmap,
            "cont_lims": cont_lims,
        }
    )
    # plot map and contour
    im = ax.imshow(
        Zi,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        origin="lower",
        aspect="equal",
        extent=extent,
        interpolation="bilinear",
    )
    ax.autoscale(enable=True, tight=True)
    ax.figure.colorbar(im, cax=ax_cbar)
    cont = ax.contour(Xi, Yi, Zi, levels=cont_lims, colors="k", linewidths=1)

    # Clip image and contours to the head patch.
    im.set_clip_path(patch_)
    text = ax.text(0.55, 0.95, "", transform=ax.transAxes, va="center", ha="right")
    params["text"] = text
    items.append(im)
    items.append(text)
    cont_collections = _cont_collections(cont)
    for col in cont_collections:
        col.set_clip_path(patch_)

    outlines_ = _draw_outlines(ax, outlines)

    params.update({"patch": patch_, "outlines": outlines_})
    return tuple(items) + cont_collections


def _animate(frame, ax, ax_line, params):
    """Update animated topomap.

    Per-frame callback for ``FuncAnimation``. It draws the precomputed image
    ``params["Zis"][frame]`` with contours clipped to the head patch, and
    updates the time label. If the butterfly plot is enabled, it also moves
    the red time marker.

    Parameters
    ----------
    frame : int
        Frame index supplied by ``FuncAnimation``. While ``params["pause"]``
        is set it is ignored, and ``params["frame"]`` (e.g. as stepped by the
        arrow keys) is shown instead.
    ax : Axes
        Topomap axes.
    ax_line : Axes | None
        Butterfly axes; only used if ``params["butterfly"]``.
    params : dict
        Shared state filled by ``_topomap_animation`` and ``_init_anim``.
        ``params["frame"]`` and, for butterfly plots, ``params["line"]`` are
        updated here.

    Returns
    -------
    artists : tuple
        The image, the time text, the time marker (butterfly only) and the
        contour collections.
    """
    if params["pause"]:
        frame = params["frame"]
    # Index into the evoked time axis (used for the butterfly marker).
    time_idx = params["frames"][frame]

    if params["time_unit"] == "ms":
        title = f"{params['times'][frame] * 1e3:6.0f} ms"
    else:
        title = f"{params['times'][frame]:6.3f} s"
    if params["blit"]:
        text = params["text"]
    else:
        # Without blitting the axes are redrawn from scratch every frame, so
        # recreate the text and the head outlines (mask outlines skipped).
        ax.cla()  # Clear old contours.
        text = ax.text(0.45, 1.15, "", transform=ax.transAxes)
        for k, (x, y) in params["outlines"].items():
            if "mask" in k:
                continue
            ax.plot(x, y, color="k", linewidth=1, clip_on=False)

    _hide_frame(ax)
    text.set_text(title)

    vmin = params["vmin"]
    vmax = params["vmax"]
    Xi = params["Xi"]
    Yi = params["Yi"]
    Zi = params["Zis"][frame]
    extent = params["extent"]
    cmap = params["cmap"]
    patch = params["patch"]

    im = ax.imshow(
        Zi,
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        origin="lower",
        aspect="equal",
        extent=extent,
        interpolation="bilinear",
    )
    cont_lims = params["cont_lims"]
    # Silence any warnings matplotlib emits while contouring this frame.
    with warnings.catch_warnings(record=True):
        warnings.simplefilter("ignore")
        cont = ax.contour(Xi, Yi, Zi, levels=cont_lims, colors="k", linewidths=1)

    im.set_clip_path(patch)
    cont_collections = _cont_collections(cont)
    for col in cont_collections:
        col.set_clip_path(patch)

    items = [im, text]
    if params["butterfly"]:
        # Replace the time marker, preserving the current y-limits.
        all_times = params["all_times"]
        line = params["line"]
        line.remove()
        ylim = ax_line.get_ylim()
        params["line"] = ax_line.axvline(all_times[time_idx], color="r")
        ax_line.set_ylim(ylim)
        items.append(params["line"])
    params["frame"] = frame
    return tuple(items) + cont_collections


def _pause_anim(event, params):
    """Toggle the paused state of the animation (mouse-click callback).

    ``event`` is ignored; only ``params["pause"]`` is flipped.
    """
    currently_paused = params["pause"]
    params["pause"] = not currently_paused


def _key_press(event, params):
    """Step the animation one frame with the arrow keys.

    ``'left'`` / ``'right'`` pause playback and move ``params["frame"]`` to
    the previous / next frame, clamped to ``[0, len(params["frames"]) - 1]``.
    Any other key is ignored.
    """
    key = event.key
    if key == "left":
        step = -1
    elif key == "right":
        step = 1
    else:
        return
    params["pause"] = True
    last_frame = len(params["frames"]) - 1
    if step < 0:
        params["frame"] = max(params["frame"] + step, 0)
    else:
        params["frame"] = min(params["frame"] + step, last_frame)


def _topomap_animation(
    evoked,
    ch_type,
    times,
    frame_rate,
    butterfly,
    blit,
    show,
    time_unit,
    sphere,
    image_interp,
    extrapolate,
    *,
    vmin,
    vmax,
    verbose=None,
):
    """Make animation of evoked data as topomap timeseries.

    See mne.evoked.Evoked.animate_topomap.

    Each requested time in ``times`` is snapped to the nearest sample of
    ``evoked.times``. The animation supports pausing on mouse click and
    stepping frames with the left/right arrow keys.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The figure holding the animation.
    anim : matplotlib.animation.FuncAnimation
        The animation object (also stored as ``fig.mne_animation``).

    Raises
    ------
    ValueError
        If ``times`` is empty, not 1D, or contains values outside the evoked
        time range.
    """
    from matplotlib import animation
    from matplotlib import pyplot as plt

    if ch_type is None:
        ch_type = _get_plot_ch_type(evoked, ch_type)

    time_unit, _ = _check_time_unit(time_unit, evoked.times)
    if times is None:
        times = np.linspace(evoked.times[0], evoked.times[-1], 10)
    times = np.array(times)

    if times.ndim != 1:
        raise ValueError(f"times must be 1D, got {times.ndim} dimensions")
    if times.size == 0:
        raise ValueError("times must contain at least one time point.")
    if max(times) > evoked.times[-1] or min(times) < evoked.times[0]:
        raise ValueError("All times must be inside the evoked time series.")
    # Sample index (into evoked.times) closest to each requested time.
    frames = [np.abs(evoked.times - time).argmin() for time in times]

    picks, pos, merge_channels, _, ch_type, sphere, clip_origin = _prepare_topomap_plot(
        evoked, ch_type, sphere=sphere
    )
    # Fancy indexing copies, so scaling below leaves ``evoked`` untouched.
    data = evoked.data[picks, :]
    data *= _handle_default("scalings")[ch_type]

    norm = np.min(data) >= 0
    vmin, vmax = _setup_vmin_vmax(data, vmin, vmax, norm)

    fig = plt.figure(figsize=(6, 5), layout="constrained")
    # Grid: topomap on the left, colorbar in the last column and, if
    # requested, a one-row butterfly plot at the bottom.
    shape = (8, 12)
    colspan = shape[1] - 1
    rowspan = shape[0] - bool(butterfly)
    ax = plt.subplot2grid(shape, (0, 0), rowspan=rowspan, colspan=colspan)
    if butterfly:
        ax_line = plt.subplot2grid(shape, (rowspan, 0), colspan=colspan)
    else:
        ax_line = None
    ax_cbar = plt.subplot2grid(shape, (0, colspan), rowspan=rowspan)
    ax_cbar.set_title(_handle_default("units")[ch_type], fontsize=10)
    extrapolate = _check_extrapolate(extrapolate, ch_type)

    params = dict(
        data=data,
        pos=pos,
        all_times=evoked.times,
        frame=0,
        frames=frames,
        butterfly=butterfly,
        blit=blit,
        pause=False,
        times=times,
        time_unit=time_unit,
        clip_origin=clip_origin,
        vmin=vmin,
        vmax=vmax,
    )
    init_func = partial(
        _init_anim,
        ax=ax,
        ax_cbar=ax_cbar,
        ax_line=ax_line,
        params=params,
        merge_channels=merge_channels,
        sphere=sphere,
        ch_type=ch_type,
        image_interp=image_interp,
        extrapolate=extrapolate,
        verbose=verbose,
    )
    animate_func = partial(_animate, ax=ax, ax_line=ax_line, params=params)
    pause_func = partial(_pause_anim, params=params)
    fig.canvas.mpl_connect("button_press_event", pause_func)
    key_press_func = partial(_key_press, params=params)
    fig.canvas.mpl_connect("key_press_event", key_press_func)
    if frame_rate is None:
        frame_rate = evoked.info["sfreq"] / 10.0
    interval = 1000 / frame_rate  # interval is in ms
    anim = animation.FuncAnimation(
        fig,
        animate_func,
        init_func=init_func,
        frames=len(frames),
        interval=interval,
        blit=blit,
    )
    fig.mne_animation = anim  # to make sure anim is not garbage collected
    plt_show(show, block=False)
    if "line" in params:
        # Finally remove the vertical line so it does not appear in saved fig.
        params["line"].remove()

    return fig, anim


def _set_contour_locator(vmin, vmax, contours):
    """Turn an integer ``contours`` count into explicit contour levels.

    Returns
    -------
    locator : matplotlib.ticker.MaxNLocator | None
        The locator used, or None if ``contours`` was not a positive integer.
    contours : array | object
        For a positive integer input, the interior tick values of
        ``MaxNLocator(nbins=contours + 1)`` over ``[vmin, vmax]``. For any
        other input, ``contours`` unchanged.
    """
    if not (isinstance(contours, Integral) and contours > 0):
        return None, contours
    from matplotlib import ticker

    # The outer two ticks are vmin and vmax and are dropped, so asking for
    # ``contours + 1`` bins leaves the requested number of interior levels.
    locator = ticker.MaxNLocator(nbins=contours + 1)
    levels = locator.tick_values(vmin, vmax)[1:-1]
    return locator, levels


def _plot_corrmap(
    data,
    subjs,
    indices,
    ch_type,
    ica,
    label,
    *,
    show,
    outlines,
    cmap,
    contours,
    sensors=False,
    template=False,
    sphere=None,
    image_interp=_INTERPOLATION_DEFAULT,
    extrapolate=_EXTRAPOLATE_DEFAULT,
    border=_BORDER_DEFAULT,
    show_names=False,
):
    """Customize ica.plot_components for corrmap.

    Plot one topomap per row of ``data`` (titled with the matching entry of
    ``subjs`` and, for templates, the ICA component name from ``indices``).
    More than 20 maps are split into several figures of at most 20 maps
    each. All plotting options, including ``template`` and ``sphere``, are
    forwarded to every figure.

    Returns
    -------
    fig : Figure | list of Figure
        A single figure, or a list of figures when more than 20 maps are
        plotted.
    """
    from ..channels.layout import _merge_ch_data

    if not template:
        title = "Detected components"
        if label is not None:
            title += " of type " + label
    else:
        title = "Supplied template"

    picks = list(range(len(data)))

    p = 20
    if len(picks) > p:  # plot components by sets of 20
        n_components = len(picks)
        figs = [
            _plot_corrmap(
                data[k : k + p],
                subjs[k : k + p],
                indices[k : k + p],
                ch_type,
                ica,
                label,
                show=show,
                outlines=outlines,
                cmap=cmap,
                contours=contours,
                sensors=sensors,
                template=template,
                sphere=sphere,
                image_interp=image_interp,
                extrapolate=extrapolate,
                border=border,
                show_names=show_names,
            )
            for k in range(0, n_components, p)
        ]
        return figs

    (
        data_picks,
        pos,
        merge_channels,
        names,
        _,
        sphere,
        clip_origin,
    ) = _prepare_topomap_plot(ica, ch_type, sphere=sphere)
    names = _prepare_sensor_names(names, show_names)
    outlines = _make_head_outlines(sphere, pos, outlines, clip_origin)

    data = np.atleast_2d(data)
    data = data[:, data_picks]

    # prepare data for iteration
    fig, axes, _, _ = _prepare_trellis(len(picks), ncols=5)
    fig.suptitle(title)

    for data_, ax, subject, idx in zip(data, axes, subjs, indices):
        if template:
            ttl = f"Subj. {subject}, {ica._ica_names[idx]}"
            ax.set_title(ttl, fontsize=12)
        else:
            ax.set_title(f"Subj. {subject}")
        if merge_channels:
            data_, _ = _merge_ch_data(data_, ch_type, [])
        # Each map gets its own color limits.
        _vlim = _setup_vmin_vmax(data_, None, None)
        plot_topomap(
            data_.flatten(),
            pos,
            vlim=_vlim,
            names=names,
            res=64,
            axes=ax,
            cmap=cmap,
            outlines=outlines,
            contours=contours,
            show=False,
            sensors=sensors,
            image_interp=image_interp,
            extrapolate=extrapolate,
            border=border,
        )
        _hide_frame(ax)
    fig.canvas.draw()
    plt_show(show)
    return fig


def _trigradient(x, y, z):
    """Return the gradient ``(dz/dx, dz/dy)`` of ``z`` at the points ``(x, y)``.

    A ``Triangulation`` of the points (Delaunay, since no triangles are
    given) is fitted with a ``CubicTriInterpolator``, whose gradient is
    evaluated at the triangulation nodes.
    """
    from matplotlib.tri import CubicTriInterpolator, Triangulation

    triangulation = Triangulation(x, y)
    interpolator = CubicTriInterpolator(triangulation, z)
    grad_x, grad_y = interpolator.gradient(triangulation.x, triangulation.y)
    return grad_x, grad_y


@fill_doc
def plot_arrowmap(
    data,
    info_from,
    info_to=None,
    scale=3e-10,
    vlim=(None, None),
    cnorm=None,
    cmap=None,
    sensors=True,
    res=64,
    axes=None,
    show_names=False,
    mask=None,
    mask_params=None,
    outlines="head",
    contours=6,
    image_interp=_INTERPOLATION_DEFAULT,
    show=True,
    onselect=None,
    extrapolate=_EXTRAPOLATE_DEFAULT,
    sphere=None,
):
    """Plot arrow map.

    Compute arrowmaps, based upon the Hosaka-Cohen transformation
    :footcite:`CohenHosaka1976`, these arrows represents an estimation of the
    current flow underneath the MEG sensors. They are a poor man's MNE.

    Since planar gradiometers takes gradients along latitude and longitude,
    they need to be projected to the flattened manifold span by magnetometer
    or radial gradiometers before taking the gradients in the 2D Cartesian
    coordinate system for visualization on the 2D topoplot. You can use the
    ``info_from`` and ``info_to`` parameters to interpolate from
    gradiometer data to magnetometer data.

    Parameters
    ----------
    data : array, shape (n_channels,)
        The data values to plot.
    info_from : instance of Info
        The measurement info from data to interpolate from.
    info_to : instance of Info | None
        The measurement info to interpolate to. If None, it is assumed
        to be the same as info_from, which is only valid when ``info_from``
        contains magnetometers.
    scale : float, default 3e-10
        To scale the arrows.
    %(vlim_plot_topomap)s

        .. versionadded:: 1.2
    %(cnorm)s

        .. versionadded:: 1.2
    %(cmap_topomap_simple)s
    %(sensors_topomap)s
    %(res_topomap)s
    %(axes_plot_topomap)s
    %(show_names_topomap)s
        Channel names are taken from ``info_to``.
    %(mask_topomap)s
    %(mask_params_topomap)s
    %(outlines_topomap)s
    %(contours_topomap)s
    %(image_interp_topomap)s
    %(show)s
    onselect : callable | None
        Handle for a function that is called when the user selects a set of
        channels by rectangle selection (matplotlib ``RectangleSelector``). If
        None interactive selection is disabled. Defaults to None.
    %(extrapolate_topomap)s

        .. versionadded:: 0.18

        .. versionchanged:: 0.21

           - The default was changed to ``'local'`` for MEG sensors.
           - ``'local'`` was changed to use a convex hull mask
           - ``'head'`` was changed to extrapolate out to the clipping circle.
    %(sphere_topomap_auto)s

    Returns
    -------
    fig : matplotlib.figure.Figure
        The Figure of the plot.

    Raises
    ------
    ValueError
        If ``info_from`` / ``info_to`` contain multiple or unsupported channel
        types, or if ``info_to`` is None while ``info_from`` holds gradiometers.

    Notes
    -----
    .. versionadded:: 0.17

    References
    ----------
    .. footbibliography::
    """
    from matplotlib import pyplot as plt

    from ..forward import _map_meg_or_eeg_channels

    sphere = _check_sphere(sphere, info_from)
    ch_type = _picks_by_type(info_from)

    if len(ch_type) > 1:
        raise ValueError(
            "Multiple channel types are not supported. "
            "All channels must either be of type 'grad' "
            "or 'mag'."
        )
    else:
        ch_type = ch_type[0][0]

    if ch_type not in ("mag", "grad"):
        raise ValueError(
            f"Channel type '{ch_type}' not supported. Supported channel "
            "types are 'mag' and 'grad'."
        )

    if info_to is None:
        # Arrows are computed on magnetometers, so gradiometer data needs an
        # explicit target montage to interpolate to.
        if ch_type != "mag":
            raise ValueError(
                "info_to must be provided when info_from contains 'grad' "
                "channels, since arrows are computed on 'mag' channels."
            )
        info_to = info_from
    else:
        ch_type = _picks_by_type(info_to)
        if len(ch_type) > 1:
            raise ValueError("Multiple channel types are not supported.")
        else:
            ch_type = ch_type[0][0]

        if ch_type != "mag":
            raise ValueError(f"only 'mag' channel type is supported. Got {ch_type}")

    if info_to is not info_from:
        info_to = pick_info(info_to, pick_types(info_to, meg=True))
        info_from = pick_info(info_from, pick_types(info_from, meg=True))
        # XXX should probably support the "origin" argument
        mapping = _map_meg_or_eeg_channels(
            info_from, info_to, origin=(0.0, 0.0, 0.04), mode="accurate"
        )
        data = np.dot(mapping, data)

    _, pos, _, names, _, sphere, clip_origin = _prepare_topomap_plot(
        info_to, "mag", sphere=sphere
    )
    # Honor ``show_names`` (previously accepted but silently ignored).
    names = _prepare_sensor_names(names, show_names)
    outlines = _make_head_outlines(sphere, pos, outlines, clip_origin)
    if axes is None:
        fig, axes = plt.subplots(layout="constrained")
    else:
        fig = axes.figure
    plot_topomap(
        data,
        pos,
        axes=axes,
        vlim=vlim,
        cmap=cmap,
        cnorm=cnorm,
        sensors=sensors,
        res=res,
        names=names,
        mask=mask,
        mask_params=mask_params,
        outlines=outlines,
        contours=contours,
        image_interp=image_interp,
        show=False,
        onselect=onselect,
        extrapolate=extrapolate,
        sphere=sphere,
        ch_type=ch_type,
    )
    # Arrows are the spatial gradient rotated by 90 degrees: (dy, -dx).
    x, y = tuple(pos.T)
    dx, dy = _trigradient(x, y, data)
    dxx = dy.data
    dyy = -dx.data
    axes.quiver(x, y, dxx, dyy, scale=scale, color="k", lw=1)
    plt_show(show)

    return fig


@fill_doc
def plot_bridged_electrodes(
    info, bridged_idx, ed_matrix, title=None, topomap_args=None
):
    """Topoplot electrode distance matrix with bridged electrodes connected.

    Parameters
    ----------
    %(info_not_none)s
    bridged_idx : list of tuple
        The indices of channels marked as bridged with each bridged
        pair stored as a tuple.
        Can be generated via
        :func:`mne.preprocessing.compute_bridged_electrodes`.
    ed_matrix : array of float, shape (n_epochs, n_channels, n_channels)
        The electrical distance matrix for each pair of EEG electrodes, one
        matrix per epoch (only the upper triangle needs to be filled).
        ``n_channels`` must equal the number of EEG channels in ``info``.
        Can be generated via
        :func:`mne.preprocessing.compute_bridged_electrodes`.
    title : str | None
        A title to add to the plot. If None, no title is set.
    topomap_args : dict | None
        Arguments to pass to :func:`mne.viz.plot_topomap`. The key
        ``'colorbar'`` (default True) is handled here instead.

    Returns
    -------
    fig : instance of matplotlib.figure.Figure
        The topoplot figure handle.

    Raises
    ------
    RuntimeError
        If ``ed_matrix`` does not have shape
        ``(n_epochs, n_eeg_channels, n_eeg_channels)``.

    See Also
    --------
    mne.preprocessing.compute_bridged_electrodes
    """
    import matplotlib.pyplot as plt

    from ..channels.layout import _find_topomap_coords

    # Work on a copy so the caller's dict is never modified.
    topomap_args = dict() if topomap_args is None else topomap_args.copy()
    picks = pick_types(info, eeg=True)
    defaults = dict(
        image_interp="nearest",
        cmap="summer_r",
        names=pick_info(info, picks).ch_names,
        contours=False,
    )
    for key, value in defaults.items():
        topomap_args.setdefault(key, value)
    sphere = topomap_args.get("sphere", _check_sphere(None))
    if "axes" in topomap_args:
        fig = None
    else:
        fig, new_ax = plt.subplots(layout="constrained")
        topomap_args["axes"] = new_ax
    # handle colorbar here instead of in plot_topomap
    colorbar = topomap_args.pop("colorbar", True)
    n_eeg = picks.size
    if ed_matrix.shape[1:] != (n_eeg, n_eeg):
        raise RuntimeError(
            f"Expected {(ed_matrix.shape[0], n_eeg, n_eeg)} "
            f"shaped `ed_matrix`, got {ed_matrix.shape}"
        )
    # Mirror the upper triangle into the lower one, epoch by epoch (each
    # iteration yields a view, so the copy is filled in place).
    ed_matrix = ed_matrix.copy()
    tril_idx = np.tril_indices(n_eeg)
    for epoch_ed in ed_matrix:
        epoch_ed[tril_idx] = epoch_ed.T[tril_idx]
    # Per electrode: smallest distance to any electrode, median over epochs.
    elec_dists = np.median(np.nanmin(ed_matrix, axis=1), axis=0)

    im, _ = plot_topomap(elec_dists, pick_info(info, picks), **topomap_args)
    if fig is None:
        fig = im.figure
    # add bridged connections
    for idx0, idx1 in bridged_idx:
        pos = _find_topomap_coords(info, [idx0, idx1], sphere=sphere)
        im.axes.plot([pos[0, 0], pos[1, 0]], [pos[0, 1], pos[1, 1]], color="r")
    if title is not None:
        im.axes.set_title(title)
    if colorbar:
        cbar = fig.colorbar(im, shrink=0.6)
        cbar.set_label(r"Electrical Distance ($\mu$$V^2$)")
    return fig


def plot_ch_adjacency(info, adjacency, ch_names, kind="2d", edit=False):
    """Plot channel adjacency.

    Parameters
    ----------
    info : instance of Info
        Info object with channel locations.
    adjacency : array
        Array of channels x channels shape. Defines which channels are adjacent
        to each other. Note that if you edit adjacencies
        (via ``edit=True``), this array will be modified in place.
    ch_names : list of str
        Names of successive channels in the ``adjacency`` matrix.
    kind : str
        How to plot the adjacency. Can be either ``'3d'`` or ``'2d'``.
    edit : bool
        Whether to allow interactive editing of the adjacency matrix via
        clicking respective channel pairs. Once clicked, the channel is
        "activated" and turns green. Clicking on another channel adds or
        removes adjacency relation between the activated and newly clicked
        channel (depending on whether the channels are already adjacent or
        not); the newly clicked channel now becomes activated. Clicking on
        an activated channel deactivates it. Editing is currently only
        supported for ``kind='2d'``.

    Returns
    -------
    fig : Figure
        The :class:`~matplotlib.figure.Figure` instance where the channel
        adjacency is plotted.

    Raises
    ------
    ValueError
        If ``kind`` is invalid, if ``edit=True`` with ``kind='3d'``, or if the
        size of ``adjacency`` does not match the selected channels.

    See Also
    --------
    mne.channels.get_builtin_ch_adjacencies
    mne.channels.read_ch_adjacency
    mne.channels.find_ch_adjacency

    Notes
    -----
    .. versionadded:: 1.1
    """
    import matplotlib as mpl
    import matplotlib.pyplot as plt

    _validate_type(info, Info, "info")
    _validate_type(adjacency, (np.ndarray, csr_array), "adjacency")
    # Validate early; an unknown ``kind`` would otherwise leave ``fig`` unset.
    _check_option("kind", kind, ("2d", "3d"))

    if edit and kind == "3d":
        raise ValueError("Editing a 3d adjacency plot is not supported.")

    # select relevant channels
    sel = pick_channels(info.ch_names, ch_names, ordered=True)
    info = pick_info(info, sel)

    # make sure adjacency is correct size wrt to inst:
    n_channels = len(info.ch_names)
    if adjacency.shape[0] != n_channels:
        raise ValueError(
            "``adjacency`` must have the same number of rows "
            "as the number of channels in ``info``. Found "
            f"{adjacency.shape[0]} channels for ``adjacency`` and"
            f" {n_channels} for ``inst``."
        )

    if kind == "3d":
        with plt.rc_context({"toolbar": "None"}):
            fig = plot_sensors(info, kind=kind, show=False)
        _set_3d_axes_equal(fig.axes[0])
    else:
        with plt.rc_context({"toolbar": "None"}):
            fig = plot_sensors(info, kind="topomap", show=False)
        fig.axes[0].axis("equal")

    path_collection = fig.axes[0].findobj(mpl.collections.PathCollection)
    path_collection[0].set_linewidths(0.0)

    if kind == "2d":
        path_collection[0].set_alpha(0.7)
        pos = path_collection[0].get_offsets()

        # make sure nodes are on top
        path_collection[0].set_zorder(10)

        # scale node size with number of connections
        n_connections = [np.sum(adjacency[[i]]) - 1 for i in range(adjacency.shape[0])]
        node_size = [max(x, 3) ** 2.5 for x in n_connections]
        path_collection[0].set_sizes(node_size)
    else:
        # plotting channel positions via mne.viz.plot_sensors(info) and using
        # the coordinates from info['chs'][ch_idx]['loc][:3] gives different
        # positions. Also .get_offsets gives 2d projections even for 3d points
        # so we use the private _offsets3d property...
        pos = path_collection[0]._offsets3d
        pos = np.stack([pos[0].data, pos[1].data, pos[2]], axis=1)

    ax = fig.axes[0]
    lines = dict()
    for ch_idx in range(n_channels):
        # only look at later channels so each edge is drawn once
        for ngb_idx in _get_upper_neighbours(adjacency, ch_idx):
            this_pos = pos[[ch_idx, ngb_idx], :]
            ch_pair = (ch_idx, ngb_idx)
            lines[ch_pair] = ax.plot(
                *this_pos.T, color=(0.55, 0.55, 0.55), lw=0.75
            )[0]

    if edit:
        # allow interactivity in 2d plots
        highlighted = dict()
        this_onpick = partial(
            _onpick_ch_adjacency,
            axes=ax,
            positions=pos,
            highlighted=highlighted,
            line_dict=lines,
            adjacency=adjacency,
            node_size=node_size,
            path_collection=path_collection,
        )
        fig.canvas.mpl_connect("pick_event", this_onpick)

    return fig


def _get_upper_neighbours(adjacency, ch_idx):
    """Return indices ``j > ch_idx`` of channels adjacent to ``ch_idx``.

    Works for dense arrays and ``csr_array`` alike. The row is sliced with a
    list index so that it stays 2D for both. The column indices from
    ``nonzero()`` are relative to the slice start, hence the offset.
    """
    row = adjacency[[ch_idx], ch_idx + 1 :]
    return row.nonzero()[1] + ch_idx + 1


def _onpick_ch_adjacency(
    event,
    axes=None,
    positions=None,
    highlighted=None,
    line_dict=None,
    adjacency=None,
    node_size=None,
    path_collection=None,
):
    """Handle interactivity in plot_ch_adjacency.

    Pick-event callback with three cases:

    - Clicking the highlighted node de-highlights it.
    - Clicking a node while none is highlighted highlights it.
    - Clicking a different node while one is highlighted toggles the edge
      between the two. This updates ``adjacency`` in place, the drawn lines
      in ``line_dict`` and both nodes' sizes, then moves the highlight to
      the newly clicked node.

    ``highlighted``, ``line_dict``, ``adjacency`` and ``node_size`` are shared
    mutable state owned by ``plot_ch_adjacency``.
    """
    node_ind = event.ind[0]

    if node_ind in highlighted:
        # de-select node, change its color back to normal
        highlighted.pop(node_ind).remove()
        axes.figure.canvas.draw()
        return

    if highlighted:
        # one previously highlighted node: add or remove the edge to it
        key = next(iter(highlighted))
        both_nodes = tuple(sorted([key, node_ind]))

        if both_nodes in line_dict:
            # remove line and clear the adjacency entry
            line_dict.pop(both_nodes).remove()
            _set_adjacency(adjacency, both_nodes, False)
        else:
            # add line and set the adjacency entry
            selected_pos = positions[both_nodes, :]
            line_dict[both_nodes] = axes.plot(*selected_pos.T, color="tab:green")[0]
            _set_adjacency(adjacency, both_nodes, True)

        # de-highlight previous
        highlighted.pop(key).remove()

        # ``adjacency`` is already updated above, so its row sums are the new
        # connection counts (do not add the change again). The ``- 1`` matches
        # the initial sizing in ``plot_ch_adjacency`` (presumably discounting
        # the diagonal self-connection).
        for idx in both_nodes:
            n_conn = np.sum(adjacency[[idx]]) - 1
            node_size[idx] = max(n_conn, 3) ** 2.5
        path_collection[0].set_sizes(node_size)

    # highlight the newly clicked node
    size = max(node_size[node_ind] * 2, 100)
    highlighted[node_ind] = axes.scatter(
        *positions[node_ind, :].T, color="tab:green", s=size, zorder=15
    )
    axes.figure.canvas.draw()  # make sure it renders


def _set_adjacency(adjacency, both_nodes, value):
    """Set adjacency for given node pair, caching errors for sparse arrays."""
    import warnings

    with warnings.catch_warnings(record=True):
        adjacency[both_nodes, both_nodes[::-1]] = value


@fill_doc
def plot_regression_weights(
    model,
    *,
    ch_type=None,
    sensors=True,
    show_names=False,
    mask=None,
    mask_params=None,
    contours=6,
    outlines="head",
    sphere=None,
    image_interp=_INTERPOLATION_DEFAULT,
    extrapolate=_EXTRAPOLATE_DEFAULT,
    border=_BORDER_DEFAULT,
    res=64,
    size=1,
    cmap=None,
    vlim=(None, None),
    cnorm=None,
    axes=None,
    colorbar=True,
    cbar_fmt="%1.1e",
    title=None,
    show=True,
):
    """Plot the regression weights of a fitted EOGRegression model.

    Parameters
    ----------
    model : EOGRegression
        The fitted EOGRegression model whose weights will be plotted.
    %(ch_type_topomap)s
    %(sensors_topomap)s
    %(show_names_topomap)s
    %(mask_topomap)s
    %(mask_params_topomap)s
    %(contours_topomap)s
    %(outlines_topomap)s
    %(sphere_topomap_auto)s
    %(image_interp_topomap)s
    %(extrapolate_topomap)s

        .. versionchanged:: 0.21

           - The default was changed to ``'local'`` for MEG sensors.
           - ``'local'`` was changed to use a convex hull mask
           - ``'head'`` was changed to extrapolate out to the clipping circle.
    %(border_topomap)s

        .. versionadded:: 0.20
    %(res_topomap)s
    %(size_topomap)s
    %(cmap_topomap)s
    %(vlim_plot_topomap)s
    %(cnorm)s
    %(axes_evoked_plot_topomap)s
    %(colorbar_topomap)s
    %(cbar_fmt_topomap)s
    %(title_none)s
    %(show)s

    Returns
    -------
    fig : instance of matplotlib.figure.Figure
        Figure with a topomap subplot for each channel type.

    Notes
    -----
    .. versionadded:: 1.2
    """
    import matplotlib
    import matplotlib.pyplot as plt

    from ..channels.layout import _merge_ch_data

    # Pass the measurement info so string spheres ("auto"/"eeglab") can be
    # resolved from the digitization points instead of failing.
    sphere = _check_sphere(sphere, model.info_)
    if ch_type is None:
        ch_types = model.info_.get_channel_types(unique=True, only_data_chs=True)
    else:
        ch_types = [ch_type]
    del ch_type

    # One row per artifact (regressor) channel, one column per channel type.
    nrows = model.coef_.shape[1]
    ncols = len(ch_types)

    axes_was_none = axes is None
    if axes_was_none:
        fig, axes = plt.subplots(
            nrows,
            ncols,
            squeeze=False,
            figsize=(ncols * 2, nrows * 1.5 + 1),
            layout="constrained",
        )
        # Transpose so that consuming the iterator walks down one column
        # (all artifact channels of one channel type) before the next.
        axes = axes.T.ravel()
    else:
        if isinstance(axes, matplotlib.axes.Axes):
            axes = [axes]
        fig = axes[0].get_figure()
    if len(axes) != nrows * ncols:
        raise ValueError(
            f"axes must be a list of {nrows * ncols} axes, got "
            f"length {len(axes)} ({axes})."
        )
    axes = iter(axes)

    data_picks = _picks_to_idx(model.info_, model.picks, exclude=model.exclude)
    data_info = pick_info(model.info_, data_picks)
    artifact_ch_names = [
        model.info_["chs"][idx]["ch_name"]
        for idx in _picks_to_idx(model.info_, model.picks_artifact)
    ]

    for ch_type in ch_types:
        # Always start from the caller's sphere and outlines. Both are adjusted
        # per channel type (e.g. MEG spheres are moved to device space and the
        # head mask is scaled to the sensor positions), so feeding one type's
        # result into the next would compound those adjustments.
        (
            type_picks,
            pos,
            merge_channels,
            type_names,
            ch_type,
            type_sphere,
            clip_origin,
        ) = _prepare_topomap_plot(data_info, ch_type=ch_type, sphere=sphere.copy())
        type_outlines = _make_head_outlines(
            type_sphere, pos, outlines=outlines, clip_origin=clip_origin
        )
        coef = model.coef_[type_picks]
        for data, ch_name in zip(coef.T, artifact_ch_names):
            # Use a fresh copy of this type's channel names for every plot:
            # merging may rewrite the list, and the labels must only cover the
            # channels of the current type.
            names = list(type_names)
            if merge_channels:
                data, names = _merge_ch_data(data, ch_type, names)
            names = _prepare_sensor_names(names, show_names)
            ax = next(axes)

            _plot_topomap_multi_cbar(
                data,
                pos,
                ax,
                title=f"{ch_type}/{ch_name}",
                vlim=vlim,
                cmap=cmap,
                outlines=type_outlines,
                colorbar=colorbar,
                unit="",
                cbar_fmt=cbar_fmt,
                sphere=type_sphere,
                ch_type=ch_type,
                sensors=sensors,
                names=names,
                mask=mask,
                mask_params=mask_params,
                contours=contours,
                image_interp=image_interp,
                extrapolate=extrapolate,
                border=border,
                res=res,
                size=size,
                cnorm=cnorm,
            )
    if axes_was_none:
        fig.suptitle(title)
    plt_show(show)
    return fig
